如何将数组张量转换为具有特定索引的one_hot


a = tf.constant([20, 1, 5, 3, 123, 4])

我想将其转换为tensor([0,0,0,1,0,0,0])(索引 = 3)

我怎样才能轻松地做到这一点?

我真正尝试做的是这样的:有一个深度神经网络,它有 5 个输出节点(用于分类)。假设一个前馈传播的输出是 [5, 22, 3, 4, 11](类型 tensor )。在此前馈中,标签为 1。所以,我需要打开这个索引的值,并像这样关闭其他索引:[5, 0, 0, 0, 0]。 最后,需要将值更改为 1:[1, 0, 0, 0, 0] 并在网络中反向传播(梯度)此张量。

这段代码应该可以做到。它使用 Numpy:

import numpy as np
def one_hot(y):
  y = y.reshape(len(y))
  n_values = int(np.max(y)) + 1
  return tf.convert_to_tensor(np.eye(n_values)[np.array(y, dtype=np.int32)])

我不确定这是否是您需要的,但我希望它有所帮助。例:

>>> print(one_hot(np.array([2,3,4])))
>>> [[ 0.  0.  1.  0.  0.]
     [ 0.  0.  0.  1.  0.]
     [ 0.  0.  0.  0.  1.]]

您正在寻找的不是独热编码。也许这就是您想要实现的目标:

a = tf.constant([20, 1, 5, 3, 123, 4])
c = tf.cast(tf.equal(a, 3), tf.int32)    # 3 is your matching element
with tf.Session() as sess:
    print(c.eval())
# [0 0 0 1 0 0]

编辑

如果您已经了解索引,则可以通过多种方式执行此操作。如果张量中的值有可能重复,您可以执行以下操作:

a = tf.constant([20, 1, 5, 3, 123, 4, 3])
c = tf.cast(tf.equal(a, a[3]), tf.int32)
with tf.Session() as sess:
    print(c.eval())
# [0 0 0 1 0 0 1]

但是如果你确定值没有重复,你可以借助 numpy 数组构造这个张量,如下所示:

import numpy as np
c = np.zeros((7), np.int32)
c[3] = 1
c_tensor = tf.constant(c)
with tf.Session() as sess:
    print(c_tensor.eval())
# [0 0 0 1 0 0 0]

编辑 2

根据新编辑的问题,为了执行分类任务,并且由于在我看来您没有进行自定义反向传播,因此让我为您提供您正在寻找的部分的骨架代码。

tf.reset_default_graph()
X = tf.placeholder(tf.float32, (None, 224, 224, 3))
y = tf.placeholder(tf.int32, (None))
one_hot_y = tf.one_hot(y, n_outputs)   # Generate one-hot vector
logits = My_Network(X)   # This function returns your network.
cross_entropy = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits, one_hot_y)) 
  # This function will compute softmax and get the loss function which you
  # would like to minimize.
optimizer = tf.train.AdamOptimizer(learning_rate = 0.01).minimize(cross_entropy)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for **each epoch**:
       for **generate batches of your data**:
            sess.run(optimizer, feed_dict = {X: batch_x, y: batch_y})

请花一些时间理解代码。我还建议您遵循一些有关分类任务的教程,因为它们非常可用。我建议你通过TensorFlow使用CNN。

相关内容

  • 没有找到相关文章