为九头蛇定义参数实例化的目标是2的幂



在使用实例化时,是否有办法为目标定义一个参数为2的幂?例如:

from sklearn.feature_extraction import HashingVectorizer
vec = HashingVectorizer(n_features=2**18)
vec.transform(["a quick fox"])
<1x262144 sparse matrix of type '<class 'numpy.float64'>'
with 2 stored elements in Compressed Sparse Row format>

正如预期的那样,输出是一个形状为(1,262144)的稀疏向量,相当于2**18。

但是,在配置文件中不能使用值2**18,因为它是作为字符串传入的。

config.yaml

vec:
_target_: sklearn.feature_extraction.text.HashingVectorizer
n_features: 2**18

test.py

import hydra
import hydra.utils as hu

@hydra.main(config_path='conf', config_name='config')
def main(cfg):
vec = hu.instantiate(cfg.vec)
vec.transform(['Erroneous Monk'])

if __name__ == "__main__":
main()

运行此示例,您将得到以下内容:

python test.py
...
TypeError: n_features must be integral, got '2**18' (<class 'str'>).

是否有一种方法可以通知hydra该值不应被视为字符串?

算术表达式目前不支持OmegaConf(底层配置库)。您可以使用自定义解析器实现某些东西。例如,您可以注册一个名为pow的自定义解析器,它将对两个输入调用Python power函数。

import hydra
import hydra.utils as hu
from omegaconf import OmegaConf
# register the resolver before you access the config field.
OmegaConf.register_new_resolver("pow", lambda x,y: x**y)
@hydra.main(config_path='conf', config_name='config')
def main(cfg):
vec = hu.instantiate(cfg.vec)
vec.transform(['Erroneous Monk'])

if __name__ == "__main__":
main()

您的配置可以定义为:

vec:
_target_: sklearn.feature_extraction.text.HashingVectorizer 
n_features: ${pow:2,18}

最新更新