可以从模块的nn.compact内部更新模块的参数吗?(自修改网络)



我对亚麻很陌生,我想知道正确的方法是什么:

param = f.init(key,x) 
new_param, y = f.apply(param,x) 

其中f为nn。模块实例。
其中f可能经过多个操作来获得new_param,而这些操作可能依赖于中间参数来产生它们的输出。

基本上,是否有一种方法可以访问并更新提供给nn实例的参数。模块从__call__,同时不丢失函数属性,这样就可以用梯度函数变换来包装。

您可以将参数视为可变的var。参考https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html

@nn.compact
def __call__(self, x):
some_params = self.variable('mutable_params', 'some_params', init_fn)
# 'mutable_params' is the variable collection name
# at the same "level" as 'params'
vars_init = model.init(key, x)
# vars_init = {'params': nested_dict_for_params, 'mutable_params': nested_dict_for_mutable_params}
y, mutated_vars = model.apply(vars_init, x, mutable=['mutable_params'])
vars_new = vars_init | mutated_vars # I'm not sure frozendict support | op
# equiv to vars_new = {'params': vars_init['params'], 'mutable_params': mutated_vars['mutable_params']}

最新更新