How to slice a tensor with a list of indices and compose a new tensor



How can I slice a tensor with a TensorFlow loop and compose a new tensor? What I mean is:

import tensorflow as tf

text_embeding = tf.constant(
    # index 0           index 1           index 2
    [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]],
     [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]],
     [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]]]
)

I want every tensor in the batch to yield a new list of values according to the index combinations index_list = [[0,0],[1,1],[2,2],[0,1],[1,2],[0,2]].

The result I want is:

[
 [
  [0.1,0.2,0.3, 0.1,0.2,0.3],   # index 0,0
  [0.4,0.5,0.6, 0.4,0.5,0.6],   # index 1,1
  [0.1,0.2,0.3, 0.1,0.2,0.3],   # index 2,2
  [0.1,0.2,0.3, 0.4,0.5,0.6],   # index 0,1
  [0.4,0.5,0.6, 0.1,0.2,0.3],   # index 1,2
  [0.1,0.2,0.3, 0.1,0.2,0.3]    # index 0,2
 ],
 [
  [0.1,0.2,0.3, 0.1,0.2,0.3],   # index 0,0
  [0.4,0.5,0.6, 0.4,0.5,0.6],   # index 1,1
  [0.1,0.2,0.3, 0.1,0.2,0.3],   # index 2,2
  [0.1,0.2,0.3, 0.4,0.5,0.6],   # index 0,1
  [0.4,0.5,0.6, 0.1,0.2,0.3],   # index 1,2
  [0.1,0.2,0.3, 0.1,0.2,0.3]    # index 0,2
 ],
 [
  [0.1,0.2,0.3, 0.1,0.2,0.3],   # index 0,0
  [0.4,0.5,0.6, 0.4,0.5,0.6],   # index 1,1
  [0.1,0.2,0.3, 0.1,0.2,0.3],   # index 2,2
  [0.1,0.2,0.3, 0.4,0.5,0.6],   # index 0,1
  [0.4,0.5,0.6, 0.1,0.2,0.3],   # index 1,2
  [0.1,0.2,0.3, 0.1,0.2,0.3]    # index 0,2
 ]
]

"我的代码是这样的,但batch_size=output_layer_sequence.shape[0]在会话图准备好之前为None,这是错误!

span_logit_sum = []
for i in range(2):
    for j in range(2):
        # batch size is hard-coded to 16 because output_layer_sequence.shape[0] is None at graph-build time
        vsp = tf.batch_gather(output_layer_sequence, tf.tile([[j, j + i]], multiples=[16, 1]))  # batch * 2 * hidden_size
        # vsp = tf.batch_gather(output_layer_sequence, tf.tile([[j, j + i]], multiples=[output_layer_sequence.shape[0], 1]))  # fails: shape[0] is None
        vsp_start, vsp_end = tf.split(vsp, 2, 1)  # batch * 1 * hidden_size each
        vsp_start = tf.squeeze(vsp_start)  # batch * hidden_size
        vsp_end = tf.squeeze(vsp_end)  # batch * hidden_size
        vsp = tf.concat([vsp_start, vsp_end], axis=-1, name='concat')  # [batch, 2*hidden_size]
        span_logits = tf.matmul(vsp, output_span_weight, transpose_b=True)  # [batch, class_labels]
        span_logits = tf.nn.bias_add(span_logits, output_span_bias)  # [batch, class_labels]
        span_logit_sum.append(span_logits)
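If the only problem is the hard-coded 16, one common workaround (not from the original post) is to read the batch dimension with tf.shape(), whose result is a tensor resolved at run time rather than at graph-construction time. A minimal sketch, assuming output_layer_sequence is a [batch, seq_len, hidden_size] tensor and i, j are the loop variables above:

batch_size = tf.shape(output_layer_sequence)[0]        # dynamic batch size, a scalar tensor
multiples = tf.stack([batch_size, 1])                  # tile factors assembled at run time
indices = tf.tile([[j, j + i]], multiples=multiples)   # [batch, 2]
vsp = tf.batch_gather(output_layer_sequence, indices)  # [batch, 2, hidden_size]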

Thanks!

Use tf.gather() along the sequence axis (axis=1):

import tensorflow as tf

text_embeding = tf.constant(
    # index 0           index 1           index 2
    [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]],
     [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]],
     [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]]]
)
index_list = tf.constant([[0, 0], [1, 1], [2, 2], [0, 1], [1, 2], [0, 2]])

output = tf.gather(text_embeding, index_list, axis=1)  # [3, 6, 2, 3]: the two embeddings of each pair, still separate
output = tf.reshape(output, [3, 6, 6])                 # flatten each pair into one concatenated vector

with tf.Session() as sess:
    print(sess.run(output))  # matches the expected result above
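The same pattern drops into the model code from the question and removes the batch-size dependency entirely, because gathering along axis=1 never touches the batch dimension. Below is a sketch, assuming (as in the question) that output_layer_sequence is [batch, seq_len, hidden_size], output_span_weight is [class_labels, 2*hidden_size], and output_span_bias is [class_labels]; the einsum replacing the per-pair matmul is an adaptation, not the original author's code:

import tensorflow as tf

# all index pairs built once, outside any loop
index_list = tf.constant([[0, 0], [1, 1], [2, 2], [0, 1], [1, 2], [0, 2]])

pairs = tf.gather(output_layer_sequence, index_list, axis=1)     # [batch, num_pairs, 2, hidden_size]
vsp_start, vsp_end = tf.unstack(pairs, num=2, axis=2)            # each [batch, num_pairs, hidden_size]
vsp = tf.concat([vsp_start, vsp_end], axis=-1)                   # [batch, num_pairs, 2*hidden_size]

# project every pair at once; tf.matmul does not broadcast over the pair axis, so use einsum
span_logits = tf.einsum('bph,ch->bpc', vsp, output_span_weight)  # [batch, num_pairs, class_labels]
span_logits = tf.nn.bias_add(span_logits, output_span_bias)      # [batch, num_pairs, class_labels]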
