使用pydantic验证类(hydra核心列表)



1。上下文

如何在pydantic中验证特定的类?

我使用pydantic来验证hydra解析的yaml列表参数,稍后将传递给建模例程。问题是,hydra字典包含的不是值列表,而是包含这些值的类。如何验证这些参数?

2.示例

在以下示例中,有2个文件:

包含待验证参数的cfg.yaml
  • 包含加载和验证cfg.yaml的指令的main.py
  • 2.1配置文件cfg.yaml

    params_list:
    - 10
    - 0
    - 20
    

    2.2分析器/验证器文件main.py

    import hydra
    import pydantic
    from omegaconf import DictConfig, OmegaConf
    from typing import List
    class Test(pydantic.BaseModel):
    params_list: List[int]
    @hydra.main(config_path=".", config_name="cfg.yaml")
    def go(cfg: DictConfig):
    parsed_cfg = Test(**cfg)
    print(parsed_cfg)
    if __name__ == "__main__":
    go()
    

    3.问题

    当执行python3 main.py时,出现以下错误

    值不是有效列表(type=type_error.list)

    这是因为hydra有一个用于处理列表的特定类,称为omegaconf.listconfig.ListConfig,可以通过添加进行检查

    print(type(cfg['params_list']))
    

    就在CCD_ 8函数定义之后。

    4.指导

    我知道我可能必须告诉pydantic来验证这个特定的东西,但我只是不知道具体如何验证。

    • 这里提供了一些提示,但我想这似乎对任务很有用
    • 另一个想法是为数据属性创建一个泛型类型(如params_list: Generic),然后使用验证器装饰器将其转换为列表,大致如下:
    class ParamsList(pydantic.BaseModel):
    params_list: ???????? #i don't know that to do here
    @p.validator("params_list")
    @classmethod
    def validate_path(cls, v) -> None:
    """validate if it's a list"""
    if type(list(v)) != list:
    raise TypeError("It's not a list. Make it become a list")
    return list(v)
    

    帮助!:你知道怎么解决吗

    如何重新创建示例

    1. 在文件夹中添加第2.1节和第2.2节中描述的文件
    2. 同时使用包pydantichydra-core创建一个requirements.txt文件
    3. 创建并激活env后,运行python3 main.py
    Pydantic不接受DictConfig格式。当您尝试使用pydantic模型解析hydra配置时,必须首先将DictConfig转换为本地Python Dict.OmegaConf.to_object(cfg)

    我假设您使用的是Python 3.10或更高版本。注意使用version_base="1.2"可以获得最新的hydra版本。

    这应该有效:

    import hydra
    import pydantic
    from omegaconf import DictConfig, OmegaConf
    class Test(pydantic.BaseModel):
    params_list: list[int]
    
    @hydra.main(config_path=".", config_name="cfg.yaml", version_base="1.2")
    def go(cfg: DictConfig):
    print(cfg)
    d_cfg = OmegaConf.to_object(cfg)
    parsed_cfg = Test(**d_cfg)
    print(parsed_cfg)
    
    if __name__ == "__main__":
    go()