从张量流中的张量返回 k 个最小元素



我在张量流中有一个形状(16, 512, 4096)的张量,我想从张量中计算出k个最小的元素。

请注意,我能够使用以下代码片段在 pytorch 中获取它-

#inputs.shape (16L, 512L, 4096L)
dists, inputs_idx = torch.topk(inputs, 64, 2, largest=False, sorted=False)
#dists.shape (16L, 512L, 64L), inputs_idx.shape (16L, 512L, 64L)

请问有什么解决方法吗?

由于tf.math.top_k可用于获取k最大的元素,因此您可以对值求反,执行操作,然后再次否定它们以取回值:

-tf.math.top_k(
-x, k
)

传递sorted=False以获得与您在问题中发送的代码等效的公式。

最新更新