如何编写这个自定义损失函数,使其为每个样本产生损失?



我对ccc使用这个自定义损失函数

def ccc(y_true, y_pred):
ccc = ((ccc_v(y_true, y_pred) + ccc_a(y_true, y_pred)) / 2)
return 1 - ccc
def ccc_v(y_true, y_pred):
x = y_true[:,0]
y = y_pred[:,0]
x_mean = K.mean(x, axis=0)
y_mean = K.mean(y, axis=0)
covar = K.mean( (x - x_mean) * (y - y_mean) )
x_var = K.var(x)
y_var = K.var(y)
ccc = (2.0 * covar) / (x_var + y_var + (x_mean + y_mean)**2)
return ccc
def ccc_a(y_true, y_pred):
x = y_true[:,1]
y = y_pred[:,1]
x_mean = K.mean(x, axis=0)
y_mean = K.mean(y, axis=0)
covar = K.mean( (x - x_mean) * (y - y_mean) )
x_var = K.var(x)
y_var = K.var(y)
ccc = (2.0 * covar) / (x_var + y_var + (x_mean + y_mean)**2)
return ccc

当前损失函数ccc返回一个标量。损失函数被分成2个不同的函数(ccc_vccc_a),因为我也使用它们作为度量。

我从Keras文档和这个问题中读到,自定义损失函数应该返回一个损失列表,每个示例一个。

第一个问题:即使损失函数返回标量,我的模型也可以训练。有那么糟糕吗?如果我使用一个输出是标量的损失函数而不是一个标量列表,训练有什么不同?

第二个问题:我如何重写我的损失函数来返回一个损失列表?我知道我应该避免均值和和,但在我的情况下,我认为这是不可能的,因为没有一个全局均值,而是不同的均值,一个是协方差的分子,两个是方差的分母。

如果你使用tensorflow,有自动计算损失的api

tf.keras.losses.mse()
tf.keras.losses.mae()
tf.keras.losses.Huber()

# Define the loss function
def loss_function(w1, b1, w2, b2, features = borrower_features, targets =     default):
predictions = model(w1, b1, w2, b2)
# Pass targets and predictions to the cross entropy loss
return keras.losses.binary_crossentropy(targets, predictions)
#if your using categorical_crossentropy than return the losses for it.

#convert your image into a single np.array for input
#build your SoftMax model

# Define a sequential model
model=keras.Sequential()
# Define a hidden layer
model.add(keras.layers.Dense(16, activation='relu', input_shape=(784,)))
# Define the output layer
model.add(keras.layers.Dense(4,activation='softmax'))
# Compile the model
model.compile('SGD', loss='categorical_crossentropy',metrics=['accuracy'])
# Complete the fitting operation
train_data=train_data.reshape((50,784))
# Fit the model
model.fit(train_data, train_labels, validation_split=0.2, epochs=3)
# Reshape test data
test_data = test_data.reshape(10, 784)
# Evaluate the model
model.evaluate(test_data, test_labels)

最新更新