Why doesn't the custom training loop average the loss over the batch_size?



The code snippet below is from the official TensorFlow custom training loop tutorial: https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch. Another tutorial also does not average the loss over the batch_size, as shown here: https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough

Why isn't loss_value averaged over the batch_size at the line loss_value = loss_fn(y_batch_train, logits)? Is this a bug? Judging from another question here, "Loss function works with reduce_mean but not reduce_sum", a reduce_mean really is needed to average the loss over the batch size.

loss_fn is defined in the tutorial as follows. It apparently does not average over the batch_size.

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

My reading of the documentation was that keras.losses.SparseCategoricalCrossentropy sums the loss over the whole batch without averaging, so that it is essentially a reduce_sum rather than a reduce_mean:

Type of tf.keras.losses.Reduction to apply to loss. Default value is AUTO. AUTO indicates that the reduction option will be determined by the usage context. For almost all cases this defaults to SUM_OVER_BATCH_SIZE.
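
One way to check what the default reduction actually does is to compare it against an explicit keras.losses.Reduction.SUM. A minimal sketch, using made-up labels and logits for a batch of 4 samples and 3 classes:

import numpy as np
from tensorflow import keras

# Made-up toy batch: 4 samples, 3 classes.
y_true = np.array([0, 1, 2, 1])
logits = np.array([[2.0, 0.5, 0.1],
                   [0.3, 1.5, 0.2],
                   [0.1, 0.4, 2.2],
                   [1.0, 1.0, 1.0]], dtype=np.float32)

# Default reduction: AUTO, which resolves to SUM_OVER_BATCH_SIZE here.
loss_default = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Explicit SUM: adds the per-sample losses without dividing.
loss_sum = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=keras.losses.Reduction.SUM)

print(float(loss_default(y_true, logits)))            # default value
print(float(loss_sum(y_true, logits)) / len(y_true))  # sum / batch_size: same value

The two printed values match, i.e. SUM_OVER_BATCH_SIZE means "sum the per-sample losses, then divide by the batch size".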

The code looks like this:

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:

            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * 64))

Update: I've figured it out. By default, loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True) does average the loss over the batch_size.
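
As a sanity check, the sketch below (using hypothetical random logits) compares loss_fn against an explicit per-sample cross entropy followed by tf.reduce_mean; the two values match:

import numpy as np
import tensorflow as tf
from tensorflow import keras

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Hypothetical batch: 4 samples, 3 classes, random logits.
y_batch = np.array([0, 2, 1, 1])
logits = tf.random.normal((4, 3))

# Per-sample cross entropy, then an explicit mean over the batch.
per_sample = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=y_batch, logits=logits)
manual_mean = tf.reduce_mean(per_sample)

# Both print the same value: the default loss_fn already averages.
print(float(loss_fn(y_batch, logits)), float(manual_mean))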
