使用softmax计算熵时数据类型不匹配



我使用下面的代码来计算预测标签与实际标签之间的交叉熵。数据来自 CIFAR-10 数据集。

我先用 astype(np.float32) 将源数据转换为数组,然后在 tf.constant() 中指定 dtype=tf.float32,结果得到如下错误信息:

TypeError: DataType float32 for attr 'Tlabels' not in list of allowed values: int32, int64

也就是说,该参数只允许 int32 和 int64 两种数据类型。但如果不在以上两个步骤中显式指定数据类型,matmul() 操作又会报错,因为计算中使用的权重是 float 类型。

# Load the first CIFAR-10 training batch exactly once. The `with` block
# guarantees the file handle is closed (the original left it open and also
# repeated the whole load a second time).
with open('cifar-10-batches-py/data_batch_1', 'rb') as f:
    datadict = cPickle.load(f, encoding='bytes')

# Pixel data stays float32 so tf.matmul() against the float32 weights works.
X = np.asarray(datadict[b"data"]).astype(np.float32)  # b prefix is for bytes string literal.
# Labels are class indices. sparse_softmax_cross_entropy_with_logits only
# accepts int32/int64 labels (attr 'Tlabels'), so they must not be float32 --
# this was the cause of the TypeError.
Y = np.asarray(datadict[b'labels']).astype(np.int32)

graph = tf.Graph()
with graph.as_default():
    tf_train_data = tf.constant(X, dtype=tf.float32)
    tf_train_labels = tf.constant(Y, dtype=tf.int32)
    # NOTE(review): X_test / Y_test are not defined in this snippet; they are
    # presumably loaded elsewhere (e.g. from the CIFAR-10 test batch) -- confirm.
    # Test labels are cast to int32 for the same reason as the training labels.
    tf_test_data = tf.constant(X_test, dtype=tf.float32)
    tf_test_labels = tf.constant(np.asarray(Y_test).astype(np.int32), dtype=tf.int32)
    print(tf_train_labels.get_shape())
    weights = tf.Variable(tf.truncated_normal([3072, 10]))
    # tf.rank() builds a Tensor, so printing it shows the op rather than the
    # rank value; the static rank is available directly from the shape.
    print(weights.get_shape().ndims)
    biases = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(tf_train_data, weights) + biases
    print(logits.get_shape().ndims,
          tf_train_labels.get_shape().ndims,
          tf_test_labels.get_shape().ndims)
    # Named arguments work with every TF 1.x signature of this op; later
    # releases reject positional (logits, labels) calls outright.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf_train_labels, logits=logits))
    optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
    train_prediction = tf.nn.softmax(logits)
    test_prediction = tf.nn.softmax(tf.matmul(tf_test_data, weights) + biases)

完整的错误信息如下:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-74-8e1ffbeb5013> in <module>()
     11     logits = tf.matmul(tf_train_data, weights) + biases
     12     print (tf.rank(logits), tf.rank(tf_train_labels), tf.rank(tf_test_labels))
---> 13     loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits, tf_train_labels))
     14     optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
     15     train_prediction = tf.nn.softmax(logits)
/Users/ayada/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/nn_ops.py in sparse_softmax_cross_entropy_with_logits(logits, labels, name)
    562     if logits.get_shape().ndims == 2:
    563       cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
--> 564           precise_logits, labels, name=name)
    565       if logits.dtype == dtypes.float16:
    566         return math_ops.cast(cost, dtypes.float16)
/Users/ayada/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/gen_nn_ops.py in _sparse_softmax_cross_entropy_with_logits(features, labels, name)
   1538   """
   1539   result = _op_def_lib.apply_op("SparseSoftmaxCrossEntropyWithLogits",
-> 1540                                 features=features, labels=labels, name=name)
   1541   return _SparseSoftmaxCrossEntropyWithLogitsOutput._make(result)
   1542 
/Users/ayada/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py in apply_op(self, op_type_name, name, **keywords)
    527             for base_type in base_types:
    528               _SatisfiesTypeConstraint(base_type,
--> 529                                        _Attr(op_def, input_arg.type_attr))
    530             attrs[input_arg.type_attr] = attr_value
    531             inferred_from[input_arg.type_attr] = input_name
/Users/ayada/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py in _SatisfiesTypeConstraint(dtype, attr_def)
     58           "DataType %s for attr '%s' not in list of allowed values: %s" %
     59           (dtypes.as_dtype(dtype).name, attr_def.name,
---> 60            ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
     61 
     62 
TypeError: DataType float32 for attr 'Tlabels' not in list of allowed values: int32, int64

如何修复

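tf.nn.sparse_softmax_cross_entropy_with_logits 要求标签是整数类型(int32 或 int64)的类别索引(稀疏标签)。因此只需把标签转换为整数(例如 Y.astype(np.int32),并使用 tf.constant(Y, dtype=tf.int32)),而输入数据和权重仍保持 float32,matmul() 就不会受影响。如果要使用独热编码(one-hot)的浮点标签,请改用 tf.nn.softmax_cross_entropy_with_logits。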

相关内容

最新更新