我在张量流中有一个形状(16, 512, 4096)
的张量,我想从张量中计算出k
个最小的元素。
请注意,我能够使用以下代码片段在 pytorch 中获取它-
#inputs.shape (16L, 512L, 4096L)
dists, inputs_idx = torch.topk(inputs, 64, 2, largest=False, sorted=False)
#dists.shape (16L, 512L, 64L), inputs_idx.shape (16L, 512L, 64L)
请问有什么解决方法吗?
由于tf.math.top_k
可用于获取k
最大的元素,因此您可以对值求反,执行操作,然后再次否定它们以取回值:
-tf.math.top_k(
-x, k
)
传递sorted=False
以获得与您在问题中发送的代码等效的公式。