我希望我的数据类有一个可以手动提供的字段,如果没有,则在初始化时从其他字段推断。MWE:
from collections.abc import Sized
from dataclasses import dataclass
from typing import Optional
@dataclass
class Foo:
data: Sized
index: Optional[list[int]] = None
def __post_init__(self):
if self.index is None:
self.index = list(range(len(self.data)))
reveal_type(Foo.index) # Union[None, list[int]]
reveal_type(Foo([1,2,3]).index) # Union[None, list[int]]
如何以这样的方式实现:
- 符合
mypy
类型检查 index
保证为list[int]
类型
我考虑过使用default_factory(list)
,但是,如何区分传递index=[]
的用户和sentinel值?除了,还有合适的解决方案吗
index: list[int] = None # type: ignore[assignment]
使用NotImplemented
from collections.abc import Sized
from dataclasses import dataclass
@dataclass
class Foo:
data: Sized
index: list[int] = NotImplemented
def __post_init__(self):
if self.index is NotImplemented:
self.index = list(range(len(self.data)))
您可以让default_factory
返回一个列表,其中sentinel对象是其唯一的元素。您只需要确保sentinel是int
的实例,否则mypy
会抱怨。幸运的是,我们有身份比较,以确保__post_init__
中的检查始终正确。
from collections.abc import Sized
from dataclasses import dataclass, field
@dataclass
class Foo:
class _IdxSentinel(int):
pass
_idx_sentinel = _IdxSentinel()
@staticmethod
def _idx_sentinel_factory() -> list[int]:
return [Foo._idx_sentinel]
data: Sized
index: list[int] = field(default_factory=_idx_sentinel_factory)
def __post_init__(self) -> None:
if len(self.index) == 1 and self.index[0] is self.__class__._idx_sentinel:
self.index = list(range(len(self.data)))
我把整个工厂和哨兵逻辑放在Foo
的中,但如果你不喜欢,你也可以考虑一下:
from collections.abc import Sized
from dataclasses import dataclass, field
class _IdxSentinel(int):
pass
_idx_sentinel = _IdxSentinel()
def _idx_sentinel_factory() -> list[int]:
return [_idx_sentinel]
@dataclass
class Foo:
data: Sized
index: list[int] = field(default_factory=_idx_sentinel_factory)
def __post_init__(self) -> None:
if len(self.index) == 1 and self.index[0] is _idx_sentinel:
self.index = list(range(len(self.data)))
EDIT:受@SUTerliakov评论的启发,这里有一个稍微不那么详细的版本,它仍然使用lambda
表达式而不是命名函数来满足类型检查器和linters的要求:
from collections.abc import Sized
from dataclasses import dataclass, field
@dataclass
class Foo:
class _IdxSentinel(int):
pass
_idx_sentinel = _IdxSentinel()
data: Sized
index: list[int] = field(default_factory=lambda: [Foo._idx_sentinel])
def __post_init__(self) -> None:
if len(self.index) == 1 and self.index[0] is self.__class__._idx_sentinel:
self.index = list(range(len(self.data)))