Pydantic: type-hinting TensorFlow tensors



Any idea how to type-hint TF tensors with pydantic? I tried the default tf.Tensor:

RuntimeError: no validator found for <class 'tensorflow.python.framework.ops.Tensor'>, see `arbitrary_types_allowed` in Config

and tf.float32:

RuntimeError: error checking inheritance of tf.float32 (type: DType)

Looking at the pydantic docs, I think I need to define an arbitrary class like this:

class Tensor:
    def __init__(self, Tensor):
        self.Tensor = Union[
            tensorflow.python.framework.ops.Tensor,
            tensorflow.python.framework.sparse_tensor.SparseTensor,
            tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor,
            tensorflow.python.framework.ops.EagerTensor,
        ]

and then use it in the main model:

class Main(BaseModel):
    tensor: Tensor

    class Config:
        arbitrary_types_allowed = True
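A minimal sketch of the Union idea, assuming pydantic v2 and TensorFlow 2.x: instead of storing the Union inside a wrapper class (which only creates an unrelated class named Tensor), define the Union as a module-level type alias. The alias name TFTensor is hypothetical. With arbitrary_types_allowed, pydantic falls back to an isinstance check for each member of the Union. tf.Tensor already covers eager tensors, so EagerTensor does not need to be listed separately.

```python
from typing import Union

import tensorflow as tf
from pydantic import BaseModel, ConfigDict, ValidationError

# Hypothetical alias: any dense, sparse or ragged TensorFlow tensor.
# Eager tensors (from tf.constant) are subclasses of tf.Tensor in TF 2.x.
TFTensor = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]


class Main(BaseModel):
    # Let pydantic accept classes it has no built-in validator for;
    # it then validates each Union member with a plain isinstance check.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    tensor: TFTensor


dense = Main(tensor=tf.constant([1.0, 2.0]))
ragged = Main(tensor=tf.ragged.constant([[1], [2, 3]]))

try:
    Main(tensor=[1.0, 2.0])  # a plain list is not a TF tensor
    list_rejected = False
except ValidationError:
    list_rejected = True
```

On pydantic v1 the nested `class Config: arbitrary_types_allowed = True` from the question plays the same role as `model_config`.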

Working code:

from pydantic import BaseModel
import tensorflow as tf

class MyTensor(BaseModel):
    tensor: tf.Variable

    class Config:
        arbitrary_types_allowed = True
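The tf.float32 attempt fails for a different reason: tf.float32 is an instance of tf.dtypes.DType, not a class, so it cannot serve as a type annotation at all. A sketch, assuming pydantic v2: annotate the field as tf.Tensor and enforce the dtype at runtime in a field_validator. The model name Float32Tensor and the validator name check_dtype are hypothetical.

```python
import tensorflow as tf
from pydantic import BaseModel, ConfigDict, ValidationError, field_validator


class Float32Tensor(BaseModel):
    # tf.Tensor has no built-in pydantic validator, so allow arbitrary types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    tensor: tf.Tensor

    @field_validator("tensor")
    @classmethod
    def check_dtype(cls, value: tf.Tensor) -> tf.Tensor:
        # Runs after the isinstance check; tf.float32 is a DType value,
        # so compare the tensor's dtype against it instead of annotating with it.
        if value.dtype != tf.float32:
            raise ValueError(f"expected a float32 tensor, got {value.dtype.name}")
        return value


ok = Float32Tensor(tensor=tf.constant([1.0, 2.0]))

try:
    Float32Tensor(tensor=tf.constant([1, 2]))  # int32 tensor
    int_rejected = False
except ValidationError:
    int_rejected = True
```

On pydantic v1 the same check would be written with the `@validator("tensor")` decorator instead of `@field_validator`.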
