tf api usage : stop_if_no_decrease_hook



我想在目标检测（Object Detection）模型的训练中使用 TensorFlow 的 stop_if_no_decrease_hook API。

但加上这些 hook 之后，训练过程仍然不会提前停止。

# Early-stopping hook keyed on the metric named 'loss'.
# NOTE(review): per the TF docs these hooks look the metric up in the
# estimator's evaluation results, so `metric_name` presumably has to match a
# key the model reports at eval time -- confirm against the eval output.
early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=5,
    min_steps=0
)

# Second hook: request a stop once 'total_loss' drops below the threshold.
early_stopping2 = tf.contrib.estimator.stop_if_lower_hook(
    estimator,
    metric_name='total_loss',
    threshold=10,
    eval_dir=None,
    min_steps=0,
    run_every_secs=60,
    run_every_steps=None
)

# Both hooks are attached to the training spec.
train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=train_steps,
    hooks=[early_stopping, early_stopping2])

我不知道什么是正确的metric_name。

import numpy as np
import tensorflow as tf
import logging
from tensorflow.python.training import session_run_hook
class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests a stop when a monitored tensor stops improving.

    The monitored tensor is added to the fetches of each ``run`` call (see
    ``before_run``) and its value is compared with the best value seen so
    far (see ``after_run``).  ``patience`` is therefore counted in ``run``
    calls of the monitored session, not in evaluation rounds.

    Args:
        monitor: Name of the graph element (tensor or operation) to watch.
            It is resolved against the default graph in ``begin``.
        min_delta: Minimum change that counts as an improvement.
        patience: Number of consecutive non-improving ``run`` calls after
            which a stop is requested.
        mode: One of ``'auto'``, ``'min'`` or ``'max'``.  In ``'auto'`` mode
            the value is maximised when ``'acc'`` occurs in ``monitor`` and
            minimised otherwise.  Unknown modes fall back to ``'auto'``.
    """

    def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                 mode='auto'):
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0

        if mode not in ('auto', 'min', 'max'):
            # Exactly one argument for the single %s placeholder; an extra
            # argument makes the logging module report a formatting error
            # instead of emitting the warning.
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        elif 'acc' in self.monitor:
            # 'auto' mode: names containing 'acc' are maximised.
            self.monitor_op = np.greater
        else:
            self.monitor_op = np.less

        minimising = self.monitor_op is np.less
        if minimising:
            # Flip the sign so that `current - min_delta` demands a decrease
            # of at least `min_delta` when minimising.
            self.min_delta *= -1
        # np.inf (lowercase) is available in every NumPy release; the np.Inf
        # alias was removed in NumPy 2.0.
        self.best = np.inf if minimising else -np.inf

    def begin(self):
        """Resolve the monitored name to a tensor in the default graph."""
        graph = tf.get_default_graph()
        self.monitor = graph.as_graph_element(self.monitor)
        if isinstance(self.monitor, tf.Operation):
            # An operation was named: watch its first output tensor.
            self.monitor = self.monitor.outputs[0]

    def before_run(self, run_context):  # pylint: disable=unused-argument
        """Ask the session to also fetch the monitored tensor."""
        return session_run_hook.SessionRunArgs(self.monitor)

    def after_run(self, run_context, run_values):
        """Update the best value and request a stop once patience runs out."""
        current = run_values.results
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            return

        self.wait += 1
        if self.wait >= self.patience:
            run_context.request_stop()

最新更新