N-dim数值列表类型



如何在Python 3.7中定义n维数值列表(张量)的类型?这将被用作Pydantic的BaseModel的道具之一。

我想要像

from typing import List, Union
NumericalList = Union[
int, float,
List[int], List[float],
List[List[int]], List[List[float]],
...
]
n1: NumericalList = [0]
n3: NumericalList = [ [ [0, 1, 2], [1, 1, 2] ],
[ [1, 2, 3], [0, 1, 3] ],
]

我知道在类中可以写字符串字面值来指示子类型是相同的。(或者只是加入from __future__ import annotations。)我想要一个iterable/slicable,而不是通过props访问。

原以为递归定义可能会起作用,但typing模块在deepcopy的许多级别之后,AttributeError: __forward_arg__就失败了。

NumericalList = Union[int, float, List["NumericalList"]]  # AttributeError: __forward_arg__
但是,请注意,当使用Pydantic而不是在Python的IDE中使用时,这将失败。这是Pydantic特有的还是一种错误的做法?

虽然我希望pydantic支持本地递归类型,但您可以使用pydantic自定义根类型模型和pydantic严格类型来确保float值不会变成int

from __future__ import annotations
from typing import Union, List
from pydantic import BaseModel, StrictInt, StrictFloat
class NumericalList(BaseModel):
__root__: Union[StrictInt, StrictFloat, List[NumericalList]]

NumericalList.update_forward_refs()

n1: NumericalList = NumericalList.parse_obj([0])
"""
NumericalList(__root__=[NumericalList(__root__=0)])
"""
n2: NumericalList = NumericalList.parse_obj(
[ [ [0, 1, 2], [1, 1, 2] ],
[ [1, 2, 3], [0, 1, 3] ],
]
)
"""
NumericalList(__root__=[
NumericalList(__root__=[
NumericalList(__root__=[NumericalList(__root__=0), NumericalList(__root__=1), NumericalList(__root__=2)]),
NumericalList(__root__=[NumericalList(__root__=1), NumericalList(__root__=1), NumericalList(__root__=2)])
]),
NumericalList(__root__=[
NumericalList(__root__=[NumericalList(__root__=1), NumericalList(__root__=2), NumericalList(__root__=3)]),
NumericalList(__root__=[NumericalList(__root__=0), NumericalList(__root__=1), NumericalList(__root__=3)])
])
])
"""

n1.dict()
"""
{"__root__": [0]}
"""
n2.dict()
"""
{"__root__": [ 
[ [0, 1, 2], [1, 1, 2] ],
[ [1, 2, 3], [0, 1, 3] ],
]}
"""

你可以扩展类来覆盖列表函数,但这是可选的

class NumericalList(BaseModel):
__root__: Union[StrictInt, StrictFloat, List[NumericalList]]
def __iter__(self):
return iter(self.__root__)
def __getitem__(self, index):
return self.__root__[index]
def __setitem__(self, index, value):
self.__root__[index] = value

NumericalList.update_forward_refs()
n3: NumericalList = NumericalList.parse_obj([0, 5, [1]])
for i in n3:
print(i)
"""
__root__=0
__root__=5
__root__=[NumericalList(__root__=1)]
"""

如果你想获得原生类型你可以做

n3.dict()
"""
{'__root__': [0, 5, [1]]}
"""