当配置路径是父文件夹时,Hydra中的模式验证不起作用



我有以下项目设置:

configs/
├── default.yaml
└── trainings
├── data_config
│   └── default.yaml
├── simple.yaml
└── schema.yaml

文件内容如下:

app.py:

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from omegaconf import MISSING, DictConfig, OmegaConf
import hydra
from hydra.core.config_store import ConfigStore
CONFIGS_DIR_PATH = Path(__file__).parent / "configs"
TRAININGS_DIR_PATH = CONFIGS_DIR_PATH / "trainings"

class Sampling(Enum):
UPSAMPLING = 1
DOWNSAMPLING = 2

@dataclass
class DataConfig:
sampling: Sampling = MISSING

@dataclass
class TrainerConfig:
project_name: str = MISSING
data_config: DataConfig = MISSING

# @hydra.main(version_base="1.2", config_path=CONFIGS_DIR_PATH, config_name="default")
@hydra.main(version_base="1.2", config_path=TRAININGS_DIR_PATH, config_name="simple")
def run(configuration: DictConfig):
sampling = OmegaConf.to_container(cfg=configuration, resolve=True)["data_config"]["sampling"]
print(f"{sampling} Type: {type(sampling)}")

def register_schemas():
config_store = ConfigStore.instance()
config_store.store(name="base_schema", node=TrainerConfig)

if __name__ == "__main__":
register_schemas()
run()

配置/default.yaml:

defaults:
- /trainings@: simple
- _self_
project_name: test

配置/培训/simple.yaml:

defaults:
- base_schema
- data_config: default
- _self_
project_name: test

配置/培训/schema.yaml:

defaults:
- data_config: default
- _self_
project_name: test

配置/培训/data_config/default.yaml:

defaults:
- _self_
sampling: DOWNSAMPLING

现在,当我像上面所示的那样运行app.py时,我得到了预期的结果(即,"DOWNSAMPLING"被解析为enum类型)。但是,当我尝试运行应用程序时,它从父目录中的default.yaml构建配置,然后我得到这个错误:

所以,当代码像这样:

...
@hydra.main(version_base="1.2", config_path=CONFIGS_DIR_PATH, config_name="default")
# @hydra.main(version_base="1.2", config_path=TRAININGS_DIR_PATH, config_name="simple")
def run(configuration: DictConfig):
...

我得到下面的错误:

In 'trainings/simple': Could not load 'trainings/base_schema'.
Config search path:
provider=hydra, path=pkg://hydra.conf
provider=main, path=file:///data/code/demos/hydra/configs
provider=schema, path=structured://
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

我不明白为什么指定要使用的模式会导致这个问题。有人知道为什么和可以做什么来解决这个问题吗?

如果您在多个配置文件中使用默认列表,我强烈建议您充分阅读并理解默认列表页面。在默认列表中寻址的配置相对于包含配置的配置组。这个错误告诉你Hydra正在training中查找base_schema,因为加载base_schema的默认列表在training中。

或者在注册时将base_schema放在training中:

config_store.store(group="trainings", name="base_schema", node=TrainerConfig)

或者在默认列表中使用绝对寻址(例如在configs/training/simple.yaml中):

defaults:
- /base_schema
- data_config: default
- _self_

相关内容

  • 没有找到相关文章

最新更新