如何在dm-haiku中重新排序不同的参数集



dm-haiku中,神经网络的参数在字典中定义,其中键是模块(和子模块)名称。如果您想遍历这些值,有多种方法可以这样做,如本期dm-haiku所示。但是,字典不尊重模块的顺序,因此很难解析子模块。例如,如果我有2个linear层,每个层后面跟着一个mlp层,那么使用hk.data_structures.traverse(params)将(大致)返回:

['linear', 'linear_2', 'mlp/~/1', 'mlp/~/2'].

而我希望它返回:

['linear', 'mlp/~/1', 'linear_2', 'mlp/~/2'].

我想要这种形式的原因是,如果创建一个可逆的神经网络,并希望扭转params被调用的顺序,隔离替代部分用于其他目的(例如迁移学习),或者,一般来说,想要更多地控制如何以及在哪里(重新)使用训练参数。

为了处理这个问题,我已经采取了regex的名称,并把它们放在我想要的顺序,然后使用hk.data_structures.filter(predicate, params)按排序模块名称过滤。虽然,如果每次我想这样做时都必须重新创建一个正则表达式,这是相当乏味的。

我想知道是否有一种方法可以将参数的dm-haiku字典转换为具有层次结构和排序的pytree之类的东西,从而使此更容易?我相信equinox以这种方式处理参数(我将很快深入了解如何做到这一点),但是我想检查一下,看看我是否忽略了一个允许对params的字典进行分组、反转和其他排列的简单方法?

根据源代码https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/filtering.py#L42#L46 haiku使用dict的排序函数(haiku参数自0.0.6以来为香草dict)用于hk.data_structures.traverse。因此,如果不修改函数本身,就无法得到想要的结果。顺便说一下,我不太明白你所说的"颠倒params的顺序"是什么意思。所有参数都在输入中一起传递,然后唯一决定使用顺序的是函数本身的体系结构,所以你应该手动反转正向传递,但你不需要在params中改变一些东西。

最新更新