tf.train.nantensorhook(损失,fail_on_nan_loss = false)仍将以TF1.0提



当我定义自定义的model_fn wtih tf1.0时,我想在损失为nan时停止训练。我在Model_fn中尝试了以下代码:

    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions_dict,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        training_hooks=[tf.train.NanTensorHook(loss, fail_on_nan_loss=False)])

但是 fail_on_nan_loss = false 仍然会引起异常,我希望它会写下警告消息并停止特定培训而不会引起例外。

关于如何正确使用tf.train.nantensorhook的任何建议?

当我探索解决方案时,一项可能的工作可能会有所帮助:我从basic_session_run_hooks.py复制nantensorhook类,并在下面的model_fn中进行我自己的呼叫版本

        class NanTensorHook2(tf.train.SessionRunHook):
      """NaN Loss monitor by Lei.
      Monitors loss and stops training if loss is NaN.
      Can either fail with exception or just stop training.
      """
      def __init__(self, loss_tensor, fail_on_nan_loss=True):
        """Initializes NanLoss monitor.
        Args:
          loss_tensor: `Tensor`, the loss tensor.
          fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
        """
        self._loss_tensor = loss_tensor
        self._fail_on_nan_loss = fail_on_nan_loss
      def before_run(self, run_context):  # pylint: disable=unused-argument
        return tf.train.SessionRunArgs(self._loss_tensor)
      def after_run(self, run_context, run_values):
        if (np.isnan(run_values.results) or np.isinf(run_values.results)):
          failure_message = "Model diverged with loss = NaN or Inf."
          if self._fail_on_nan_loss:
            logging.error(failure_message)
            raise NanLossDuringTrainingError
          else:
            logging.warning(failure_message)
            # We don't raise an error but we request stop without an exception.
            run_context.request_stop()

然后使用NantenSorhook2,然后开始工作。

指出我添加了" np.isinf(run_values.results)" ,因为我相信损失= inf也应在此处检查。

任何专家都有更好的解决方案?

相关内容

最新更新