Python根据目标与元素之间的差异对张量进行排序



假设原张量在下面

tensor([[0.9950, 0.6175, 0.1253, 1.3536],
[0.1208, 0.4237, 1.1313, 0.9022],
[1.1995, 0.0699, 0.4396, 0.8043]])

我想根据1和元素之间的差值对张量进行排序,更接近1的元素会在张量的前面,所以排序后的张量会
sorted_tensor([[ 0.9950, 1.3536, 0.6175, 0.1253],
[ 0.9022, 1.1313, 0.4237, 0.1208],
[ 1.1995, 0.8043, 0.4396, 0.0699]])

它们的任何功能是火炬提供的吗?提前谢谢。

可能这样做会起作用:

diffs_from_one = torch.abs(tensor - 1)
indices = torch.argsort(diffs_from_one)
sorted_tensor = tensor[indices]

注意我不是特别熟悉PyTorch。但总的想法是,使用argsort来找到应该对原始张量排序的指标,似乎是可行的。

相关内容

  • 没有找到相关文章

最新更新