我试图为Keras定义一个自定义rmse损失函数。我写了下面的函数
import keras.backend as K
def custom_rmse(y_true, y_pred):
loss = K.square(y_pred - y_true)
for i in range(len(y_true)):
for j in range(y_true.shape[1]):
tmp = float(y_true[i][j])
if (tmp < 0.15):
loss[i][j] *= 0.2
else:
loss[i][j] *=0.8
loss = K.sqrt(K.sum(loss, axis=1))
return loss
但是当我运行模型并试图修复它时,我一直得到这个错误
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:853 train_function *
return step_function(self, iterator)
<ipython-input-95-efab27dd2563>:8 custom_rmse *
if (tmp < 0.15):
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/operators/control_flow.py:1172 if_stmt
_tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/operators/control_flow.py:1219 _tf_if_stmt
cond, aug_body, aug_orelse, strict=True)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/deprecation.py:549 new_func
return func(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/cond_v2.py:88 cond_v2
op_return_value=pred)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/operators/control_flow.py:1197 aug_body
set_state(init_vars)
/tmp/tmp_3e6lmrw.py:35 set_state
(loss[i][j],) = vars_
TypeError: 'Tensor' object does not support item assignment
我将感谢如何解决这个问题的建议。谢谢。
If-Else语句通常不是用于损失函数的方法。大多数时候,最好做一个"软"字。你想要达到的目标。这可以通过(例如)以以下方式对损失值使用陡峭的逻辑函数来实现:
def custom_rmse(y_true, y_pred):
loss = K.square(y_pred - y_true)
logistic_values = tf.sigmoid(1000 * (y_true - 0.15))
loss = logistic_values * loss * 0.8 + (1-logistic_values * loss * 0.2)
loss = K.sqrt(K.sum(loss, axis=1))
return loss
这段代码将做以下事情:
- 我们从y_true中减去0.15(您的阈值),因此新值的阈值现在为0。
- 我们将结果乘以一个较大的数字(这里我选择1000,数字越大,"软阈值"就越陡)。是。这意味着,所有高于阈值的值现在都是非常高的正值,所有低于阈值的值现在都是高负值。
- 我们对结果值应用s形函数。对于所有高正值,该函数将为1,对于所有高负值,该函数将为-0(中间有软过渡)。
- 现在,我们只需将损失乘以logistic_values或1-logistic_values,它基本上充当掩码,分别屏蔽掉所有0或1的值。所有未被掩盖的值现在可以乘以它们各自的0.8或0.2的因子。