删除PyTorch张量中特定维度上包含nan的所有索引



我有一个torch张量,其形式如下:

a = torch.tensor([[[2,3],
[3,4],
[4,5]],
[[3,6],
[6,2],
[2,-1]],
[[float('nan'), 1],
[2,3], 
[3,2]])

我想返回另一个移除了nan的张量,但也移除了同一维度上的所有项。所以期待

a_clean =    torch.tensor([[[3,4],
[4,5]],
[[6,2],
[2,-1]],
[[2,3], 
[3,2]])

对如何实现这一点有什么想法吗?

这可以使用Tensor.isnanTensor.any和一些创造性的索引来完成:

>>> a[:, ~a.isnan().any(dim=2).any(dim=0), :]
tensor([[[ 3.,  4.],
[ 4.,  5.]],
[[ 6.,  2.],
[ 2., -1.]],
[[ 2.,  3.],
[ 3.,  2.]]])

请注意,您试图删除维度1上的条目,因此索引发生在维度1上。将isnan()的结果在除1之外的每个维度上进行约简,可以告诉我们1维度中哪些索引包含NaN值。把这些放在一起就得到了上面的表达式。

最新更新