tf.keras.layers.BatchNormalization with trainable=False似乎不会更



我正在努力了解BatchNormalization层在TensorFlow中的具体表现。我提出了以下代码,据我所知,它应该是一个完全有效的keras模型,但BatchNormalization的均值和方差似乎没有更新。

来自文档https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

在BatchNormalization层的情况下,在层上设置trainable=False意味着该层随后将以推理模式运行(意味着它将使用移动平均值和移动方差来规范当前批次,而不是使用当前批次的平均值和方差(。

我希望模型在每次后续的预测调用中返回不同的值。然而,我看到的是10次返回的完全相同的值。有人能向我解释为什么BatchNormalization层不更新其内部值吗?

import tensorflow as tf
import numpy as np
if __name__ == '__main__':
np.random.seed(1)
x = np.random.randn(3, 5) * 5 + 0.3
bn = tf.keras.layers.BatchNormalization(trainable=False, epsilon=1e-9)
z = input = tf.keras.layers.Input([5])
z = bn(z)
model = tf.keras.Model(inputs=input, outputs=z)
for i in range(10):
print(x)
print(model.predict(x))
print()

我使用TensorFlow 2.1.0

好吧,我在假设中发现了错误。移动平均值是在训练期间更新的,而不是像我想的那样在推理期间更新。这是完全合理的,因为在推理过程中更新移动平均值可能会导致不稳定的产生模型(例如,高度病态的输入样本的长序列[例如,它们的生成分布与训练网络的分布大不相同]可能会使网络产生偏误,并导致有效输入样本的性能变差(。

当您微调预训练的模型并希望冻结网络的某些层时,即使在训练期间,可训练参数也很有用。因为当您调用model.predict(x)(甚至model(x)model(x, training=False)(时,层会自动使用移动平均值而不是批处理平均值。

下面的代码清楚地展示了

import tensorflow as tf
import numpy as np
if __name__ == '__main__':
np.random.seed(1)
x = np.random.randn(10, 5) * 5 + 0.3
z = input = tf.keras.layers.Input([5])
z = tf.keras.layers.BatchNormalization(trainable=True, epsilon=1e-9, momentum=0.99)(z)
model = tf.keras.Model(inputs=input, outputs=z)

# a dummy loss function
model.compile(loss=lambda x, y: (x - y) ** 2)
# a dummy fit just to update the batchnorm moving averages
model.fit(x, x, batch_size=3, epochs=10)

# first predict uses the moving averages from training
pred = model(x).numpy()
print(pred.mean(axis=0))
print(pred.var(axis=0))
print()

# outputs the same thing as previous predict
pred = model(x).numpy()
print(pred.mean(axis=0))
print(pred.var(axis=0))
print()

# here calling the model with training=True results in update of moving averages
# furthermore, it uses the batch mean and variance as in training, 
# so the result is very different
pred = model(x, training=True).numpy()
print(pred.mean(axis=0))
print(pred.var(axis=0))
print()

# here we see again that the moving averages are used but they differ slightly after
# the previous call, as expected
pred = model(x).numpy()
print(pred.mean(axis=0))
print(pred.var(axis=0))
print()

最后,我发现(https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization)提到这一点:

  1. 当使用包含批量规范化的模型执行推理时,通常(尽管并非总是(希望使用累积统计信息,而不是小批量统计信息。这是通过在调用模型时传递training=False或使用model.prpredict来实现的

希望这将在未来帮助有类似误解的人。

最新更新