打印TF 2.x中每个epoch的损失值



我在TF 2.3中编写了自定义损失函数,其中损失由几个子损失组成,因为我想跟踪子损失,所以我使用tf.print打印了它们:

def custom_loss_envelop(model_inputs,  model, num_bound,model_outputs,lambda_, ener):
def custom_loss(y_true,y_pred):
l1 = ....
l2 = ....
l3 = ....
tf.print("l1:",tf.math.round(l1 * 100)/100,", l2:", tf.math.round(l2 * 100) / 100,
", l3:", tf.math.round(l3 * 100) / 100,
", l4:", tf.math.round(l4 * 100) / 100)
loss = l1 + l2 + l3 + l4
return loss
return custom_loss

问题是这段代码打印每批损失,而我只希望它每个epoch。有什么办法吗?

您可以使用keras.callbacks.Callback()类与以下定义的函数:

def on_epoch_end(self, epoch, logs=None):
keys = list(logs.keys()) 
# you could do more here ...            
print("End epoch {} of training; got log keys: {}".format(epoch, keys))

注意epoch的结束只在训练期间有效。查看文档中的其他选项https://www.tensorflow.org/guide/keras/custom_callback?hl=en

相关内容

  • 没有找到相关文章

最新更新