我想在我的hydra配置中使用python数据类型作为参数,既有内置的,也有从numpy、tensorflow等库导入的。类似于:
# config.yaml
arg1: np.float32
arg2: tf.float16
我现在正在做这个:
# config.yaml
arg1: 'float32'
arg2: 'float16
# my python code
# ...
DTYPES_LOOKUP = {
'float32': np.float32,
'float16': tf.float16
}
arg1 = DTYPES_LOOKUP[config.arg1]
arg2 = DTYPES_LOOKUP[config.arg2]
是否有更为的水力/优雅的解决方案?
hydra.utils.get_class
函数能为您解决这个问题吗?
# config.yaml
arg1: numpy.float32 # note: use "numpy" here, not "np"
arg2: tensorflow.float16
# python code
...
from hydra.utils import get_class
arg1 = get_class(config.arg1)
arg2 = get_class(config.arg2)
更新1:使用自定义冲突解决程序
基于miccio下面的评论,这里有一个使用OmegaConf自定义解析器来包装get_class
函数的演示。
from omegaconf import OmegaConf
from hydra.utils import get_class
OmegaConf.register_new_resolver(name="get_cls", resolver=lambda cls: get_class(cls))
config = OmegaConf.create("""
# config.yaml
arg1: "${get_cls: numpy.float32}"
arg2: "${get_cls: tensorflow.float16}"
""")
arg1 = config.arg1
arg1 = config.arg2
更新2:
事实证明,get_class("numpy.float32")
成功了,但get_class("tensorflow.float16")
引发了ValueError。原因是get_class
检查返回的值是否确实是一个类(使用isinstance(cls, type)
(。
函数hydra.utils.get_method
稍微宽松一些,只检查返回的值是否是可调用的,但这仍然不适用于tf.float16
。
>>> isinstance(tf.float16, type)
False
>>> callable(tf.float16)
False
封装tensorflow.as_dtype
函数的自定义解析器可能是合适的。
>>> tf.as_dtype("float16")
tf.float16