我有以下项目设置:
configs/
├── default.yaml
└── trainings
├── data_config
│ └── default.yaml
├── simple.yaml
└── schema.yaml
文件内容如下:
app.py:
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from omegaconf import MISSING, DictConfig, OmegaConf
import hydra
from hydra.core.config_store import ConfigStore
CONFIGS_DIR_PATH = Path(__file__).parent / "configs"
TRAININGS_DIR_PATH = CONFIGS_DIR_PATH / "trainings"
class Sampling(Enum):
UPSAMPLING = 1
DOWNSAMPLING = 2
@dataclass
class DataConfig:
sampling: Sampling = MISSING
@dataclass
class TrainerConfig:
project_name: str = MISSING
data_config: DataConfig = MISSING
# @hydra.main(version_base="1.2", config_path=CONFIGS_DIR_PATH, config_name="default")
@hydra.main(version_base="1.2", config_path=TRAININGS_DIR_PATH, config_name="simple")
def run(configuration: DictConfig):
sampling = OmegaConf.to_container(cfg=configuration, resolve=True)["data_config"]["sampling"]
print(f"{sampling} Type: {type(sampling)}")
def register_schemas():
config_store = ConfigStore.instance()
config_store.store(name="base_schema", node=TrainerConfig)
if __name__ == "__main__":
register_schemas()
run()
配置/default.yaml:
defaults:
- /trainings@: simple
- _self_
project_name: test
配置/培训/simple.yaml:
defaults:
- base_schema
- data_config: default
- _self_
project_name: test
配置/培训/schema.yaml:
defaults:
- data_config: default
- _self_
project_name: test
配置/培训/data_config/default.yaml:
defaults:
- _self_
sampling: DOWNSAMPLING
现在,当我像上面所示的那样运行app.py
时,我得到了预期的结果(即,"DOWNSAMPLING"
被解析为enum类型)。但是,当我尝试运行应用程序时,它从父目录中的default.yaml
构建配置,然后我得到这个错误:
所以,当代码像这样:
...
@hydra.main(version_base="1.2", config_path=CONFIGS_DIR_PATH, config_name="default")
# @hydra.main(version_base="1.2", config_path=TRAININGS_DIR_PATH, config_name="simple")
def run(configuration: DictConfig):
...
我得到下面的错误:
In 'trainings/simple': Could not load 'trainings/base_schema'.
Config search path:
provider=hydra, path=pkg://hydra.conf
provider=main, path=file:///data/code/demos/hydra/configs
provider=schema, path=structured://
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
我不明白为什么指定要使用的模式会导致这个问题。有人知道为什么和可以做什么来解决这个问题吗?
如果您在多个配置文件中使用默认列表,我强烈建议您充分阅读并理解默认列表页面。在默认列表中寻址的配置相对于包含配置的配置组。这个错误告诉你Hydra正在training中查找base_schema,因为加载base_schema的默认列表在training中。
或者在注册时将base_schema放在training中:
config_store.store(group="trainings", name="base_schema", node=TrainerConfig)
或者在默认列表中使用绝对寻址(例如在configs/training/simple.yaml中):
defaults:
- /base_schema
- data_config: default
- _self_