如何更新tensorflow随机加权平均(SWA)中的权重?



我对如何实现tfa的SWA优化器感到困惑。这里有两点:

  1. 当你看文档时,它指向你[这个]模型平均教程。该教程使用了tfa.callback。AverageModelCheckpoint,它允许你
  • 为模型分配移动平均权值,并保存
  • (或)保留旧的非平均权值,但保存的模型使用平均权值

拥有一个允许您保存移动平均权重(而不是当前权重)的独特ModelCheckpoint是有意义的。然而,似乎SWA应该管理权重平均。这让我想设置update_weights=False

正确吗?本教程使用update_weights=True

  1. 在文档中有一个关于SWA不更新BN层的说明。按照建议,我这样做了,
# original training
model.fit(...)
# updating weights from final run 
optimizer.assign_average_vars(model.variables)
# batch-norm-hack: lr=0 as suggested https://stackoverflow.com/a/64376062/607528
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0),
loss=loss,
metrics=metrics)
model.fit(
data,
validation_data=None,
epochs=1,
callbacks=final_callbacks)

保存我的模型。

正确吗?

谢谢!

处理批规范最简单的方法如下:首先,循环遍历模型中的所有层,并重置批规范层中的移动平均值和移动方差(在我的示例中,我假设批规范层以"bn"结束):

for l in model.layers:
if l.name.split('_')[-1] == 'bn': # e.g. conv1_bn
l.moving_mean.assign(tf.zeros_like(l.moving_mean))
l.moving_variance.assign(tf.ones_like(l.moving_variance))

之后运行你的模型一个epoch,并将training设置为true以更新移动平均线和方差:

count = 0
for x,_ in dataset_train:
_ = model(x, training = True)
count += 1
if count > steps_per_epoch:
break

有两种方法要做到这一点,首先是您手动更新权重在保存之前,像文档中的这个例子一样。

import tensorflow as tf
import tensorflow_addons as tfa
model = tf.Sequential([...])
opt = tfa.optimizers.SWA(
tf.keras.optimizers.SGD(lr=2.0), 100, 10)
model.compile(opt, ...)
model.fit(x, y, ...)
# Update the weights to their mean before saving
opt.assign_average_vars(model.variables)
model.save('model.h5')

第二个选项是通过更新权重如果设置为update_weights = True,则为AverageModelCheckpoint。如collab笔记本示例所示

avg_callback = tfa.callbacks.AverageModelCheckpoint(filepath=checkpoint_dir, 
update_weights=True)
...
#Build Model
model = create_model(moving_avg_sgd)
#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])

注意AverageModelCheckpoint在保存模型之前也调用assign_average_vars,从源代码:

def _save_model(self, epoch, logs):
optimizer = self._get_optimizer()
assert isinstance(optimizer, AveragedOptimizerWrapper)
if self.update_weights:
optimizer.assign_average_vars(self.model.variables)
return super()._save_model(epoch, logs)
...

最新更新