如何使用 Keras 作为高级 API 在张量流上实现批量规范化



BatchNormalization (BN( 在训练和推理中的操作略有不同。在训练中,它使用当前小批量的平均值和方差来扩展其输入;这意味着应用批量归一化的确切结果不仅取决于当前输入,还取决于小批量的所有其他元素。在推理模式下,这显然是不可取的,因为我们知道这是一个确定性的结果。因此,在这种情况下,使用整个训练集的全局平均值和方差的固定统计量。

在 Tensorflow 中,此行为由布尔开关控制,training调用层时需要指定该开关,请参阅 https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization。使用 Keras 高级 API 时如何处理此切换?我假设它是自动处理的是否正确,具体取决于我们使用的是model.fit(x, ...)还是model.predict(x, ...)


为了测试这一点,我写了这个例子。我们从随机分布开始,想要对输入是正还是负进行分类。但是,我们还有一个来自不同分布的测试数据集,其中输入被 2 取代(因此标签检查 x>2(。

import numpy as np
from math import ceil
from tensorflow.python.data import Dataset
from tensorflow.python.keras import Input, Model
from tensorflow.python.keras.layers import Dense, BatchNormalization
np.random.seed(18)
xt = np.random.randn(10_000, 1)
yt = np.array([[int(x > 0)] for x in xt])
train_data = Dataset.from_tensor_slices((xt, yt)).shuffle(10_000).repeat().batch(32).prefetch(2)
xv = np.random.randn(100, 1)
yv = np.array([[int(x > 0)] for x in xv])
valid_data = Dataset.from_tensor_slices((xv, yv)).repeat().batch(32).prefetch(2)
xs = np.random.randn(100, 1) + 2
ys = np.array([[int(x > 2)] for x in xs])
test_data = Dataset.from_tensor_slices((xs, ys)).repeat().batch(32).prefetch(2)
x = Input(shape=(1,))
a = BatchNormalization()(x)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a)
y = Dense(1, activation='sigmoid')(a)
model = Model(inputs=x, outputs=y, )
model.summary()
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_data, epochs=10, steps_per_epoch=ceil(10_000 / 32), validation_data=valid_data,
validation_steps=ceil(100 / 32))
zs = model.predict(test_data, steps=ceil(100 / 32))
print(sum([ys[i] == int(zs[i] > 0.5) for i in range(100)]))

运行代码会打印值 0.5,这意味着一半的示例已正确标记。如果系统使用训练集的全局统计数据来实现 BN,这就是我所期望的。

如果我们将 BN 层更改为读取

x = Input(shape=(1,))
a = BatchNormalization()(x, training=True)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a, training=True)
y = Dense(1, activation='sigmoid')(a)

并再次运行代码,我们发现 0.87。强制始终处于训练状态,正确预测的百分比已更改。这与model.predict(x, ...)现在使用小批量的统计数据来实现BN的想法一致,因此能够稍微"纠正"训练数据和测试数据之间源分布中的不匹配。

这是对的吗?

如果我正确理解了您的问题,那么是的,keras 确实会根据fitpredict/evaluate自动管理训练与推理行为。该标志称为learning_phase,它决定了批处理规范、dropout 和潜在的其他事物的行为。当前的学习阶段可以用keras.backend.learning_phase()看到,用keras.backend.set_learning_phase()设置。

https://keras.io/backend/#learning_phase

最新更新