Tensorflow 2.0:模型检查点的自定义度量(平衡精度分数)不起作用



我想实现一个基于平衡准确性分数的模型检查点回调。为此,我实现了以下类:

class BalAccScore(keras.callbacks.Callback):
def __init__(self, validation_data=None):
super(BalAccScore, self).__init__()
self.validation_data = validation_data

def on_train_begin(self, logs={}):
self.balanced_accuracy = []
def on_epoch_end(self, epoch, logs={}):
y_predict = tf.argmax(self.model.predict(self.validation_data[0]), axis=1)
y_true = tf.argmax(self.validation_data[1], axis=1)
balacc = balanced_accuracy_score(y_true, y_predict)
self.balanced_accuracy.append(round(balacc,6))
logs["val_bal_acc"] = balacc
keys = list(logs.keys())
print("n ------ validation balanced accuracy score: %f ------n" %balacc)

然后我定义以下回调

balAccScore = BalAccScore(validation_data=(X_2, y_2))
mc = ModelCheckpoint(filepath=callback_path, monitor="val_bal_acc", verbose=1, save_best_only=True, save_freq='epoch')
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=['val_bal_acc'])
history = model.fit(X_1, y_1, epochs = 5, batch_size = 512,
callbacks=[balAccScore,  mc],
validation_data = (X_2, y_2)
)

然后我得到错误

ValueError: Unknown metric function: val_bal_acc

尽管事实上,当使用例如准确性时,即通过在编译时设置metrics=["acc"],我发现它是历史性的。在这种情况下,我得到了预期的警告:

WARNING:tensorflow:Can save best model only with val_bal_acc available, skipping.

但除此之外,该模型运行良好。不确定它为什么不运行。

您应该只删除compile:中的引号

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=[val_bal_acc])

或者至少这是它在R 中的工作方式

编译模型时,由于没有将balanced_accuracy_score作为metrics参数的值传递,因此会出现此错误。编译模型时在metrics参数中传递的字符串"val_bal_acc"不起作用,因为它不是已知的度量。您可以通过它们的字符串名称访问度量,仅针对那些已经在tf.keras.metrics中实现的度量。如果你想在训练期间监控验证平衡的准确性,你应该实现一个自定义的度量类(你可以在这里看看如何做到(,然后把它传递给metrics参数。完成此操作后,如果您想在验证期间监视自定义度量,则可以使用前缀为"val"的名称来监视自定义度量。不需要像您那样进行补充的自定义回调,一旦定义了度量,日志就会自动更新。对于这个特殊的情况,您可以在这个问题的答案中找到这个度量的一些实现。

如果您更喜欢回调方法,则不需要定义自定义度量,而是利用已经记录的度量。你可以在我的回答中找到它的实现。

最新更新