如何使用索引对张量执行操作



下面是一个简短的示例,说明我想要实现的功能。如何根据另一个张量中的值(作为索引)来访问并更新一个张量?

import tensorflow as tf
import numpy as np

input_pl = tf.placeholder(tf.int32, shape=[None,2]))
mapping_pl = tf.placeholder(tf.int32, shape=[None,2]))
input1 = np.asarray([[1,1],
                    [2,2],
                    [3,3]])

mapping = np.asarray([[0,1],
                  [0,2],
                  [2,2]])
with tf.Graph().as_default():
output = .....
   # add the 0th row of input1 with 1th row of output
   # add the 0th row of input1 with 2th row of output
   # add the 2th row of input1 with 2th row of output
sess = tf.Session()
output.eval(sess)

也许可以采用类似下面的做法:

tf.reset_default_graph()
input1 = tf.constant(np.array([[1,1],[2,2],[3,3]]))
output = tf.Variable(np.ones(input1.get_shape()), dtype=input1.dtype)
mapping = tf.constant(np.asarray([[0,1],[0,2],[2,2]]))
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
input1_indices = tf.reshape(tf.slice(mapping, [0, 0], [3, 1]), [-1])
output_indices = tf.reshape(tf.slice(mapping, [0, 1], [3, 1]), [-1])
input1_values = tf.gather(input1, input1_indices)
print 'updating output rows', output_indices.eval()
print 'with values from input1 rows', input1_indices.eval()
print 'output before', sess.run(output)
sess.run(tf.scatter_update(output, output_indices, input1_values))
print 'output after', sess.run(output)

最新更新