Keras 自定义损失函数 dtype 错误



我有一个NN有两个相同的CNN(类似于Siamese网络(,然后合并输出,并打算在合并的输出上应用自定义损失函数,如下所示:

-----------------        -----------------
|    input_a    |        |    input_b    |
-----------------        -----------------
| base_network  |        | base_network  |
------------------------------------------
|           processed_a_b                |
------------------------------------------

在我的自定义损失函数中,我需要将 y 垂直分成两部分,然后在每部分应用分类交叉熵损失。但是,我不断从损失函数中收到 dtype 错误,例如:

值错误回溯(最近一次调用( 最后( 在 (( 中 ----> 1 model.compile(loss=categorical_crossentropy_loss, optimizer=RMSprop(((

/usr/local/lib/python3.5/dist-packages/keras/engine/training.py in 编译(自身、优化器、损失、指标、loss_weights、 sample_weight_mode,**夸格斯( 909 loss_weight = loss_weights_list[i] 910 output_loss = weighted_loss(y_true, y_pred, --> 911 sample_weight,口罩( 912 如果 len(自输出(> 1: 913 self.metrics_tensors.附录(output_loss(

/usr/local/lib/python3.5/dist-packages/keras/engine/training.py in 加权(y_true、y_pred、砝码、掩码( 451 # 应用样本权重 452 如果权重不是 None: --> 453 score_array *= 权重 454 score_array/= K.mean(K.cast(K.not_equal(weights, 0(, K.floatx(((( 455 返回 K.均值(score_array(

/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/math_ops.py 在binary_op_wrapper(x, y( 827 如果不是 isinstance(y, sparse_tensor.稀疏张量(: 828 尝试: --> 829 y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y"( 830 除了类型错误: 831# 如果 RHS 不是张量,则可能是张量感知对象

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py 在convert_to_tensor(值、dtype、名称、preferred_dtype( 674 名称=名称, 675 preferred_dtype=preferred_dtype, --> 676 as_ref=假( 677 678

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py 在internal_convert_to_tensor(值、dtype、name、as_ref、 preferred_dtype( 739 740 如果 ret 为"无": --> 741 ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref( 742 743 如果 ret 未实现:

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py 在_TensorTensorConversionFunction(t、dtype、name、as_ref( 612 提高值错误( 613 "张量转换请求 dtype %s 对于具有 dtype %s: %r 的张量" --> 614 % (dtype.name, t.dtype.name, str(t((( 615 返回 t 616

值错误:张量转换请求 dtype float64 用于张量 dtype float32: 'Tensor("processed_a_b_sample_weights_1:0", shape=(?,(, dtype=float32('

这是重现错误的 MWE:

import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Dense, merge, Dropout
from keras.models import Model, Sequential
from keras.optimizers import RMSprop
import numpy as np
# define the inputs
input_dim = 10
input_a = Input(shape=(input_dim,), name='input_a')
input_b = Input(shape=(input_dim,), name='input_b')
# define base_network
n_class = 4
base_network = Sequential(name='base_network')
base_network.add(Dense(8, input_shape=(input_dim,), activation='relu'))
base_network.add(Dropout(0.1))
base_network.add(Dense(n_class, activation='relu'))
processed_a = base_network(input_a)
processed_b = base_network(input_b)
# merge left and right sections
processed_a_b = merge([processed_a, processed_b], mode='concat', concat_axis=1, name='processed_a_b')
# create the model
model = Model(inputs=[input_a, input_b], outputs=processed_a_b)
# custom loss function
def categorical_crossentropy_loss(y_true, y_pred):
# break (un-merge) y_true and y_pred into two pieces
y_true_a, y_true_b = tf.split(value=y_true, num_or_size_splits=2, axis=1)
y_pred_a, y_pred_b = tf.split(value=y_pred, num_or_size_splits=2, axis=1)
loss = K.categorical_crossentropy(output=y_pred_a, target=y_true_a) + K.categorical_crossentropy(output=y_pred_b, target=y_true_b) 
return K.mean(loss)
# compile the model
model.compile(loss=categorical_crossentropy_loss, optimizer=RMSprop())

正如您的错误所指示的那样,您正在使用float32数据,并且它期望float64.有必要将错误跟踪到其特定行,以确定要纠正的张量并能够更好地帮助您。

但是,它似乎K.mean()方法有关,但ValueError也可以通过K.categorical_crossentropy()方法生成。因此,问题可能出在张量loss,既y_preds,也y_true两者。鉴于这些情况,我认为您可以尝试两件事来解决问题:

  1. 您可以将张量(假设它是loss(转换为所需的(float64(类型,如下所示:

    from keras import backend as K
    new_tensor = K.cast(loss, dtype='float64')
    
  2. 通过将参数dtype传递给Input()调用(如以下示例中所建议的那样(,可以在开始时将输入声明为 float64 类型,如下所示:

    input_a = Input(shape=(input_dim,), name='input_a', dtype='float64')
    

相关内容

  • 没有找到相关文章

最新更新