使用数据类型或其他特定于库的变量作为hydra中的参数



我想在我的hydra配置中使用python数据类型作为参数,既有内置的,也有从numpy、tensorflow等库导入的。类似于:

# config.yaml
arg1: np.float32
arg2: tf.float16

我现在正在做这个:

# config.yaml
arg1: 'float32'
arg2: 'float16
# my python code
# ...
DTYPES_LOOKUP = {
'float32': np.float32,
'float16': tf.float16
}
arg1 = DTYPES_LOOKUP[config.arg1]
arg2 = DTYPES_LOOKUP[config.arg2]

是否有更为的水力/优雅的解决方案?

hydra.utils.get_class函数能为您解决这个问题吗?

# config.yaml
arg1: numpy.float32  # note: use "numpy" here, not "np"
arg2: tensorflow.float16
# python code
...
from hydra.utils import get_class
arg1 = get_class(config.arg1)
arg2 = get_class(config.arg2)

更新1:使用自定义冲突解决程序

基于miccio下面的评论,这里有一个使用OmegaConf自定义解析器来包装get_class函数的演示。

from omegaconf import OmegaConf
from hydra.utils import get_class
OmegaConf.register_new_resolver(name="get_cls", resolver=lambda cls: get_class(cls))
config = OmegaConf.create("""
# config.yaml
arg1: "${get_cls: numpy.float32}"
arg2: "${get_cls: tensorflow.float16}"
""")
arg1 = config.arg1
arg1 = config.arg2

更新2:

事实证明,get_class("numpy.float32")成功了,但get_class("tensorflow.float16")引发了ValueError。原因是get_class检查返回的值是否确实是一个类(使用isinstance(cls, type)(。

函数hydra.utils.get_method稍微宽松一些,只检查返回的值是否是可调用的,但这仍然不适用于tf.float16

>>> isinstance(tf.float16, type)
False
>>> callable(tf.float16)
False

封装tensorflow.as_dtype函数的自定义解析器可能是合适的。

>>> tf.as_dtype("float16")
tf.float16

相关内容

  • 没有找到相关文章

最新更新