KERAS自定义分类错误数量



我正在使用tensorflow后端的keras,并尝试编写一个自定义损耗函数,该函数仅计算不正确的分类预测数量。这是我的尝试:

def error_count_loss(yTrue, yPred):
    """Sum and return the number of incorrect predictions.
    Parameters
    ----------
    yTrue : One-hot encoded truth
    yPred : Softmax encoded prediction
    """
    yTrue_argmax = K.argmax(yTrue, axis=1)
    yPred_argmax = K.argmax(yPred, axis=1)
    incorrect_bool = K.not_equal(yTrue_argmax, yPred_argmax)
    incorrect_float = K.cast(incorrect_bool, 'float32')
    return K.sum(incorrect_float)

此代码失败,因为Argmax不是可区分的。是否有一种可靠的方法来计算错误的预测?

您可以尝试使用具有梯度的其他功能来编写与之接近的内容,例如:

def error_count_loss(yTrue, yPred):
    return K.sum(K.abs(K.sign(yTrue) - K.sign(yFalse)))

,但这不是训练的最佳损失功能。尝试查看cancorical_crossentropy。

最新更新