在Python数据类中使用描述符作为字段的正确方法是什么?



我一直在使用python数据类,并且想知道:制作一个或一些字段描述符的最优雅或最python的方法是什么?

在下面的例子中,我定义了一个Vector2D类,它应该比较它的长度。

from dataclasses import dataclass, field
from math import sqrt
@dataclass(order=True)
class Vector2D:
x: int = field(compare=False)
y: int = field(compare=False)
length: int = field(init=False)

def __post_init__(self):
type(self).length = property(lambda s: sqrt(s.x**2+s.y**2))
Vector2D(3,4) > Vector2D(4,1) # True
当这段代码工作时,它在每次创建实例时都触及类,是否有一种更可读/更不hack/更有意的

方法来同时使用数据类和描述符?只是将长度作为属性而不是字段将工作,但这意味着我必须写__lt__,等等。by myself.

我发现的另一个解决方案同样不吸引人:

@dataclass(order=True)
class Vector2D:
x: int = field(compare=False)
y: int = field(compare=False)
length: int = field(init=False)

@property
def length(self):
return sqrt(self.x**2+self.y**2)

@length.setter
def length(self, value):
pass

引入一个无操作setter是必要的,因为显然数据类创建的init方法试图分配长度,即使没有默认值,它显式设置init=False

肯定有更好的方法,对吧?

可能不会回答您的确切问题,但是您提到您不想将长度作为属性和not字段的原因是因为您必须

自己写__lt__

虽然您必须自己实现__lt__,但实际上您可以只实现

from functools import total_ordering
from dataclasses import dataclass, field
from math import sqrt
@total_ordering
@dataclass
class Vector2D:
x: int
y: int
@property
def length(self):
return sqrt(self.x ** 2 + self.y ** 2)
def __lt__(self, other):
if not isinstance(other, Vector2D):
return NotImplemented
return self.length < other.length
def __eq__(self, other):
if not isinstance(other, Vector2D):
return NotImplemented
return self.length == other.length

print(Vector2D(3, 4) > Vector2D(4, 1))

之所以有效是因为total_ordering只是添加了基于__eq____lt__的所有其他相等方法

我不认为您提供的示例是您正在尝试做的一个很好的用例。不过,为了完整起见,这里有一个可能的解决方案:

@dataclass(order=True)
class Vector2D:
x: int = field(compare=False)
y: int = field(compare=False)
length: int = field(default=property(lambda s: sqrt(s.x**2+s.y**2)), init=False)

这可以工作,因为dataclass将默认值设置为类属性的值,除非该值是list, dict或set。

虽然您可以手动实现@property和其他方法,但如果您想在dict中使用Vector2D,这可能会使您失去其他理想的功能,例如在本例中使用hash=False。此外,让它为你实现dunder方法使你的代码更不容易出错,例如,你不能忘记return NotImplemented,这是一个常见的错误。

缺点是实现正确的类型提示并不容易,而且可能会有一些小的注意事项,但是一旦实现了类型提示,它就可以很容易地在任何地方使用。

属性(描述符)类型提示:

import sys
from typing import Any, Optional, Protocol, TypeVar, overload
if sys.version_info < (3, 9):
from typing import Type
else:
from builtins import type as Type
IT = TypeVar("IT", contravariant=True)
CT = TypeVar("CT", covariant=True)
GT = TypeVar("GT", covariant=True)
ST = TypeVar("ST", contravariant=True)

class Property(Protocol[CT, GT, ST]):
# Get default attribute from a class.
@overload
def __get__(self, instance: None, owner: Type[Any]) -> CT:
...
# Get attribute from an instance.
def __get__(self, instance: IT, owner: Optional[Type[IT]] = ...) -> GT:
...
def __get__(self, instance, owner=None):
...
def __set__(self, instance: Any, value: ST) -> None:
...

从这里开始,我们现在可以在使用dataclass时提示property对象的类型。如果需要使用field(...)中的其他选项,请使用field(default=property(...))

import sys
import typing
from dataclasses import dataclass, field
from math import hypot
# Use for read-only property.
if sys.version_info < (3, 11):
from typing import NoReturn as Never
else:
from typing import Never

@dataclass(order=True)
class Vector2D:
x: int = field(compare=False)
y: int = field(compare=False)
# Properties return themselves as their default class variable.
# Read-only properties never allow setting a value.

# If init=True, then it would assign self.length = Vector2D.length for the
# default factory.

# Setting repr=False for consistency with init=False.
length: Property[property, float, Never] = field(
default=property(lambda v: hypot(v.x, v.y)),
init=False,
repr=False,
)

v1 = Vector2D(3, 4)
v2 = Vector2D(6, 8)
if typing.TYPE_CHECKING:
reveal_type(Vector2D.length)  # builtins.property
reveal_type(v1.length)        # builtins.float
assert v1.length == 5.0
assert v2.length == 10.0
assert v1 < v2

在my Playground上试试

最新更新