如何在张量流vectorized_map中使用 TopK 运算符



我试图对形状[batchsize, listsize]的张量执行tf.vectorized_map,并将tf.math.top_k运算符应用于批处理中的每一行,但没有成功。

例如,数据可以是:

[ [1,2,4,5,6], [9,5,4,2,1] ]

我想在[1,2,4,5,6][9,5,4,2,1]上应用 topk.

然而,我成功地用tf.map_fn做了同样的事情,但vectorized_map应该跑得更快。我使用张量流 1.15

import tensorflow as tf
import numpy as np
# create fake data
x = tf.convert_to_tensor([
[1,2,4,5,6],
[9,5,4,2,1],
], dtype=tf.float32)
x = tf.reshape(x, (2, -1))
B = x.shape[0] # batchsize
L = x.shape[1] # list size
print(f"B {B}, L {L}")
sess = tf.Session()
print(f"x tensor: {sess.run(x)}n")
def fv(_x):
#_tensor = tf.reshape(_x, (L,))  # doesnt work (1)
_tensor = tf.reshape(tf.convert_to_tensor([9,5,4,2,1], dtype=tf.float32), (L,)) # work (2)
#_tensor = tf.convert_to_tensor([9,5,4,2,1], dtype=tf.float32) # work (3)
print(f"_tensor: {_tensor}")
values, indices = tf.math.top_k(_tensor, k=3)
# i just need the indices
return indices
indices = tf.vectorized_map(
fv,
x,
)
print("nindices ")
print(sess.run(indices))

正如我们所看到的 (2( 和 (3( 运行,所以 topk 运算符应该是可用的。此外,即使 (1( 不起作用,我也可以使用 _x,例如只需像这样返回它:

def fv(_x):
return _x * 10

所以_x是可用的。

因此,当我使用 (1( 运行代码时,出现错误:

ValueError: No converter defined for TopKV2
name: "loop_body/TopKV2"
op: "TopKV2"
input: "loop_body/Reshape"
input: "loop_body/TopKV2/k"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "sorted"
value {
b: true
}
}
inputs: [WrappedTensor(t=<tf.Tensor 'loop_body/Reshape/pfor/Reshape:0' shape=(2, 5) dtype=float32>, is_stacked=True, is_sparse_stacked=False), WrappedTensor(t=<tf.Tensor 'loop_body/TopKV2/k:0' shape=() dtype=int32>, is_stacked=False, is_sparse_stacked=False)]. 
Either add a converter or set --op_conversion_fallback_to_while_loop=True, which may run slower
Process finished with exit code 1

在这里,我只是尝试获取索引,因为我需要处理向量以在K=3的输出中具有类似的[[0,0,1,1,1], [1,1,1,0,0] ](如果值在 topk 中,则为 1,否则为 0(。并且还要给出另一个形状为 [batchsize, 1] 的张量,其中包含每行的 K 参数。(我已经成功地用map_fn这样做了,所以我认为以后不会有问题(。

也许可以在矢量化映射中实现我自己的 topk 运算符,但我宁愿不这样做。

我终于做了这样的事情: 这不使用vectorized_map,但这是我想做的。但是,如果有人可以使其与vectorized_map一起使用,我会看看解决方案。:)

def topk(x, k):
"""
x : shape [B, L]
k : shape [B, 1]
return : final_mask of shape [B,L] with final_mask[b,i] = 0 if x[b,i] is in  
the k[b] biggest values of x[b,:], else final_mask[b,i] = 1
"""
B = x.shape[0]  # batchsize
L = x.shape[1]  # list size
# the indices sorted in descending order
indices_des = tf.argsort(x, axis=-1, direction='DESCENDING', stable=False, name='sorting_for_topk')

mask = tf.reshape(tf.range(start=0, limit=L, dtype=tf.int32), [1, L])
mask = tf.repeat(mask, [B], axis=0)
mask = mask<k
one_hot = tf.one_hot(indices_des, depth=L) * tf.cast(tf.reshape(mask, [B, L, 1]), tf.float32)
final_mask = tf.reduce_sum(one_hot, axis=1)

return final_mask