前向传递后的KL损失究竟在哪里使用?



我注意到,当从DenseVariationalcall方法调用self.add_loss时(即在正向传递期间),损失的KL部分被添加到Layer类的列表self._losses中。

但是,在训练期间如何处理此列表self._losses(或同一Layer类的方法losses)?在训练期间从哪里调用它?例如,在将它们添加到最终损失之前,它们是求和还是平均值?我想看看实际的代码。

我想知道这些损失与您在fit方法中指定的损失究竟是如何结合的。你能为我提供组合它们的代码吗?请注意,我对TensorFlow附带的Keras感兴趣(因为这是我正在使用的那个)。

实际上,计算总损失的部分是Modelcompile方法,特别是在以下行中:

# Compute total loss.
# Used to keep track of the total loss value (stateless).
# eg., total_loss = loss_weight_1 * output_1_loss_fn(...) +
#                   loss_weight_2 * output_2_loss_fn(...) +
#                   layer losses.
self.total_loss = self._prepare_total_loss(masks)

_prepare_total_loss方法将正则化和层损失添加到总损失中(即将所有损失相加),然后在以下行的批处理轴上将它们平均:

# Add regularization penalties and other layer-specific losses.
for loss_tensor in self.losses:
total_loss += loss_tensor
return K.mean(total_loss)

实际上,self.losses不是Model类的属性;相反,它是父类的属性,即Network,它将所有特定于层的损失作为列表返回。此外,为了解决任何混淆,total_loss上面的代码是一个单一的张量,它是模型中所有损失(即损失函数值和特定于层的损失)的总的等式。请注意,根据定义,损失函数必须为每个输入样本(而不是整个批次)返回一个损失值。因此,K.mean(total_loss)会将批处理轴上的所有这些值平均为一个最终损失值,优化器应将其最小化。


至于tf.keras这或多或少与原生keras相同;然而,事物的结构和流程有点不同,下面将解释。

首先,在compileModel方法中,创建一个损失容器,用于保存和计算损失函数的值:

self.compiled_loss = compile_utils.LossesContainer(
loss, loss_weights, output_names=self.output_names)

接下来,在类train_stepModel方法中,调用此容器来计算批处理的损失值:

loss = self.compiled_loss(
y, y_pred, sample_weight, regularization_losses=self.losses)

如上所示,self.losses被传递给此容器。与本机 Keras 实现一样,self.losses包含所有特定于层的损失值,唯一的区别是tf.keras它是在Layer类中实现的(而不是像本机 Keras 中那样Network类中实现的)。请注意,ModelNetwork的子类,而 本身就是Layer的子类。现在,让我们看看在__call__LossesContainer方法中如何处理regularization_losses(这些行):

if (loss_obj.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE or
loss_obj.reduction == losses_utils.ReductionV2.AUTO):
loss_value = losses_utils.scale_loss_for_distribution(loss_value)

loss_values.append(loss_value)
loss_metric_values.append(loss_metric_value)

if regularization_losses:
regularization_losses = losses_utils.cast_losses_to_common_dtype(
regularization_losses)
reg_loss = math_ops.add_n(regularization_losses)
loss_metric_values.append(reg_loss)
loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss))

if loss_values:
loss_metric_values = losses_utils.cast_losses_to_common_dtype(
loss_metric_values)
total_loss_metric_value = math_ops.add_n(loss_metric_values)
self._loss_metric.update_state(
total_loss_metric_value, sample_weight=batch_dim)

loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
total_loss = math_ops.add_n(loss_values)
return total_loss

如您所见,regularization_losses将被添加到total_loss中,该将保存特定于层的损失的总和以及批处理轴上所有损失函数的平均值之和(因此,它将是单个值)。

最新更新