如何在 Keras Lambda 回调中按纪元访问步数



我想知道如何从 Keras lambda 回调内部访问每个纪元的批处理数,即传递给model.fit函数的参数steps_per_epoch的值。

以下是我的自定义回调:

(我想在batch_per_epoch = ???????中填写???????(

class MyBatchLogger(keras.callbacks.Callback):
def __init__(self):
super().__init__()
self._current_epoch = 0
def on_epoch_begin(self, epoch, logs=None):
self._current_epoch = epoch
def on_epoch_end(self, epoch, logs=None):
print("Epoch end", logs)
def on_batch_end(self, batch, logs={}):
batch_per_epoch = ???????
acc = logs["acc"].item()
loss = logs["loss"].item()
mae = logs["mean_absolute_error"].item()
ca = logs["categorical_accuracy"].item()
print(json.dumps({
"timestamp": datetime.now().isoformat(),
"epoch": self._current_epoch,
"batch": batch,
"batchPerEpoch": batch_per_epoch,
"accuracy": acc,
"meanAbsoluteError": mae,
"categoricalAccuracy": ca,
"loss": loss
}))

我正在使用 Keras 2.2.5 和 Tensorflow 1.14.1,但如有必要,我可以更新。

答案可能来得有点晚,但我花了一些时间挖掘它,所以也许这无论如何都是有帮助的。

您需要的信息在这里

self.params.get('steps')

最新更新