When I define a custom model_fn with TF 1.0, I want to stop training when the loss is NaN. I tried the following code in my model_fn:
    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions_dict,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        training_hooks=[tf.train.NanTensorHook(loss, fail_on_nan_loss=False)])
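However, fail_on_nan_loss=False still raises an exception. I expected it to log a warning and stop that particular training run without raising an exception.
Any suggestions on how to use tf.train.NanTensorHook correctly?
While exploring solutions, I found one possible workaround that may help: I copied the NanTensorHook class from basic_session_run_hooks.py and call my own version of it in model_fn, as below: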
    import numpy as np
    import tensorflow as tf
    # These two imports mirror the ones used in basic_session_run_hooks.py.
    from tensorflow.python.platform import tf_logging as logging
    from tensorflow.python.training.basic_session_run_hooks import NanLossDuringTrainingError

    class NanTensorHook2(tf.train.SessionRunHook):
        """NaN Loss monitor by Lei.
        Monitors loss and stops training if loss is NaN.
        Can either fail with exception or just stop training.
        """

        def __init__(self, loss_tensor, fail_on_nan_loss=True):
            """Initializes NanLoss monitor.
            Args:
              loss_tensor: `Tensor`, the loss tensor.
              fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
            """
            self._loss_tensor = loss_tensor
            self._fail_on_nan_loss = fail_on_nan_loss

        def before_run(self, run_context):  # pylint: disable=unused-argument
            # Ask the session to fetch the loss tensor on every step.
            return tf.train.SessionRunArgs(self._loss_tensor)

        def after_run(self, run_context, run_values):
            # run_values.results holds the loss value fetched in before_run.
            if (np.isnan(run_values.results) or np.isinf(run_values.results)):
                failure_message = "Model diverged with loss = NaN or Inf."
                if self._fail_on_nan_loss:
                    logging.error(failure_message)
                    raise NanLossDuringTrainingError
                else:
                    logging.warning(failure_message)
                    # We don't raise an error but we request stop without an exception.
                    run_context.request_stop()
After switching to NanTensorHook2, it started working.
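For anyone who wants to check the stop-versus-raise branching without starting a TensorFlow session, here is a minimal sketch. It is not part of TensorFlow: FakeRunContext and handle_loss are hypothetical stand-ins that only mirror the after_run logic above, and RuntimeError is used in place of NanLossDuringTrainingError so the snippet needs nothing beyond NumPy.

```python
import logging

import numpy as np


class FakeRunContext:
    """Hypothetical stand-in for the run_context a session passes to hooks."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        # Record the request instead of stopping a real session.
        self.stop_requested = True


def handle_loss(run_context, loss_value, fail_on_nan_loss=True):
    """Mirror the after_run branches for one fetched scalar loss value."""
    if np.isnan(loss_value) or np.isinf(loss_value):
        message = "Model diverged with loss = NaN or Inf."
        if fail_on_nan_loss:
            logging.error(message)
            # Assumption: RuntimeError stands in for NanLossDuringTrainingError.
            raise RuntimeError(message)
        logging.warning(message)
        # No exception here; only ask the training loop to stop.
        run_context.request_stop()


ctx = FakeRunContext()
handle_loss(ctx, 0.25, fail_on_nan_loss=False)
stopped_on_finite = ctx.stop_requested

handle_loss(ctx, float("nan"), fail_on_nan_loss=False)
stopped_on_nan = ctx.stop_requested
```

With fail_on_nan_loss=False the sketch only sets the stop flag. With the default of True it raises instead, which is the behavior I was trying to avoid.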
Note that I added "np.isinf(run_values.results)" because I believe loss = Inf should also be checked here.
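As a side note, the combined NaN-or-Inf test can probably be written more compactly with np.isfinite, which is False for NaN, +Inf and -Inf alike. A small sketch, assuming the loss is fetched as a scalar (loss_diverged is a hypothetical helper name, not a TensorFlow or NumPy function):

```python
import numpy as np


def loss_diverged(loss_value):
    """Return True if a scalar loss is NaN, +Inf, or -Inf."""
    # np.isfinite is False for exactly the values the two-part check catches.
    return not np.isfinite(loss_value)


def loss_diverged_two_part(loss_value):
    """The two-part check from the hook, kept here for comparison."""
    return bool(np.isnan(loss_value) or np.isinf(loss_value))


# Compare both checks on a few representative float32 values.
samples = [np.float32(0.5), np.float32("nan"), np.float32("inf"), np.float32("-inf")]
checks_agree = all(loss_diverged(v) == loss_diverged_two_part(v) for v in samples)
```

If the fetched loss were an array rather than a scalar, either form would need an explicit reduction such as np.all or np.any.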
Does any expert have a better solution?