获取pytorch张量中的顶部映射值



我试图得到张量的前N个元素。我有一个映射告诉我如何对张量值排序

values_mapping = {1: 12, 3: 1, 4: 2, 2: 34, 12: 3}
tensor = torch.tensor([1, 4, 12, 2])
tensor.topk(3)

这里的结果应该是torch.tensor([1, 12, 2]),即使用values_mapping

映射它们后的顶值。是否有办法这样做,使用火炬?我们能告诉torch如何排序它得到的值吗?

我不知道是否存在更优雅的解决方案,但是您可以使用所需键选择值,从值中选择前k项,使用topk索引选择键:

values_mapping = {1: 12, 3: 1, 4: 2, 2: 34, 12: 3}
tensor0 = torch.tensor([1, 4, 12, 2])
mapping0 = torch.Tensor([(k, v) for k, v in values_mapping.items() if k in tensor0])
topk = mapping0[:,1].topk(3)
top_keys = mapping0[:,0][topk.indices]
print(top_keys)
>>> tensor([ 2.,  1., 12.])

最新更新