数据类可选字段,如果缺少则推断该字段



我希望我的数据类有一个可以手动提供的字段,如果没有,则在初始化时从其他字段推断。MWE:

from collections.abc import Sized
from dataclasses import dataclass
from typing import Optional

@dataclass
class Foo:
data: Sized
index: Optional[list[int]] = None
def __post_init__(self):
if self.index is None:
self.index = list(range(len(self.data)))
reveal_type(Foo.index)           # Union[None, list[int]]
reveal_type(Foo([1,2,3]).index)  # Union[None, list[int]]

如何以这样的方式实现:

  1. 符合mypy类型检查
  2. index保证为list[int]类型

我考虑过使用default_factory(list),但是,如何区分传递index=[]的用户和sentinel值?除了,还有合适的解决方案吗

index: list[int] = None  # type: ignore[assignment]

使用NotImplemented

from collections.abc import Sized
from dataclasses import dataclass

@dataclass
class Foo:
data: Sized
index: list[int] = NotImplemented
def __post_init__(self):
if self.index is NotImplemented:
self.index = list(range(len(self.data)))

您可以让default_factory返回一个列表,其中sentinel对象是其唯一的元素。您只需要确保sentinel是int的实例,否则mypy会抱怨。幸运的是,我们有身份比较,以确保__post_init__中的检查始终正确。

from collections.abc import Sized
from dataclasses import dataclass, field
@dataclass
class Foo:
class _IdxSentinel(int):
pass
_idx_sentinel = _IdxSentinel()
@staticmethod
def _idx_sentinel_factory() -> list[int]:
return [Foo._idx_sentinel]
data: Sized
index: list[int] = field(default_factory=_idx_sentinel_factory)
def __post_init__(self) -> None:
if len(self.index) == 1 and self.index[0] is self.__class__._idx_sentinel:
self.index = list(range(len(self.data)))

我把整个工厂和哨兵逻辑放在Foo中,但如果你不喜欢,你也可以考虑一下:

from collections.abc import Sized
from dataclasses import dataclass, field
class _IdxSentinel(int):
pass
_idx_sentinel = _IdxSentinel()
def _idx_sentinel_factory() -> list[int]:
return [_idx_sentinel]
@dataclass
class Foo:
data: Sized
index: list[int] = field(default_factory=_idx_sentinel_factory)
def __post_init__(self) -> None:
if len(self.index) == 1 and self.index[0] is _idx_sentinel:
self.index = list(range(len(self.data)))

EDIT:受@SUTerliakov评论的启发,这里有一个稍微不那么详细的版本,它仍然使用lambda表达式而不是命名函数来满足类型检查器和linters的要求:

from collections.abc import Sized
from dataclasses import dataclass, field

@dataclass
class Foo:
class _IdxSentinel(int):
pass
_idx_sentinel = _IdxSentinel()
data: Sized
index: list[int] = field(default_factory=lambda: [Foo._idx_sentinel])
def __post_init__(self) -> None:
if len(self.index) == 1 and self.index[0] is self.__class__._idx_sentinel:
self.index = list(range(len(self.data)))

相关内容

  • 没有找到相关文章

最新更新