Keras后端:argmax如果超过阈值,否则-1



我想创建一个自定义精度函数,仅当argmax的值超过阈值时才使用argmax作为y_pred,否则为-1。

就支持的Keras而言,它将是sparse_categorical_accuracy的修改:

return backend.cast(
backend.equal(
backend.flatten(y_true),
backend.cast(backend.argmax(y_pred, axis=-1),
backend.floatx())),
backend.floatx())

所以,而不是:

backend.argmax(y_pred, axis=-1)

我需要一个伪代码逻辑的函数:

argmax_values = backend.argmax(y_pred, axis=-1)
argmax_values if y_pred[argmax_values] > threshold else -1

作为一个具体例子,如果:

x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]

threshold=0.8,则所需函数的结果将是:

[-1, 0, -1, 0]

我如何使用Keras后端实现这一点?我的Keras版本是2.2.4,所以我没有访问TensorFlow 2后端。

您可以使用K.switch根据条件从两个不同的张量中有条件地分配值。使用K.switch,您想要的函数是:

from keras import backend as K
def argmax_w_threshold(y_pred, threshold=0.8):
argmax_values = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
return K.switch(
K.max(y_pred, axis=-1) > threshold,
argmax_values, 
-1. * K.ones_like(argmax_values)
)

注意K.switchthenelse部分的张量必须具有相同的形状,因此使用K.ones_like

以你为例:

>>> import tensorflow as tf
>>> sess = tf.InteractiveSession()
>>> x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
>>> sess.run(argmax_w_threshold(x))
array([-1.,  0., -1.,  0.], dtype=float32)

最新更新