我想创建一个自定义精度函数,仅当argmax
的值超过阈值时才使用argmax
作为y_pred
,否则为-1。
就支持的Keras而言,它将是sparse_categorical_accuracy
的修改:
return backend.cast(
backend.equal(
backend.flatten(y_true),
backend.cast(backend.argmax(y_pred, axis=-1),
backend.floatx())),
backend.floatx())
所以,而不是:
backend.argmax(y_pred, axis=-1)
我需要一个伪代码逻辑的函数:
argmax_values = backend.argmax(y_pred, axis=-1)
argmax_values if y_pred[argmax_values] > threshold else -1
作为一个具体例子,如果:
x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
和threshold=0.8
,则所需函数的结果将是:
[-1, 0, -1, 0]
我如何使用Keras后端实现这一点?我的Keras版本是2.2.4
,所以我没有访问TensorFlow 2后端。
您可以使用K.switch
根据条件从两个不同的张量中有条件地分配值。使用K.switch
,您想要的函数是:
from keras import backend as K
def argmax_w_threshold(y_pred, threshold=0.8):
argmax_values = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
return K.switch(
K.max(y_pred, axis=-1) > threshold,
argmax_values,
-1. * K.ones_like(argmax_values)
)
注意K.switch
的then
和else
部分的张量必须具有相同的形状,因此使用K.ones_like
。
以你为例:
>>> import tensorflow as tf
>>> sess = tf.InteractiveSession()
>>> x = [[0.75, 0.25], [0.85, 0.15], [0.5, 0.5], [0.95, 0.05]]
>>> sess.run(argmax_w_threshold(x))
array([-1., 0., -1., 0.], dtype=float32)