我有一个torch
张量,其形式如下:
a = torch.tensor([[[2,3],
[3,4],
[4,5]],
[[3,6],
[6,2],
[2,-1]],
[[float('nan'), 1],
[2,3],
[3,2]])
我想返回另一个移除了nan
的张量,但也移除了同一维度上的所有项。所以期待
a_clean = torch.tensor([[[3,4],
[4,5]],
[[6,2],
[2,-1]],
[[2,3],
[3,2]])
对如何实现这一点有什么想法吗?
这可以使用Tensor.isnan
、Tensor.any
和一些创造性的索引来完成:
>>> a[:, ~a.isnan().any(dim=2).any(dim=0), :]
tensor([[[ 3., 4.],
[ 4., 5.]],
[[ 6., 2.],
[ 2., -1.]],
[[ 2., 3.],
[ 3., 2.]]])
请注意,您试图删除维度1
上的条目,因此索引发生在维度1
上。将isnan()
的结果在除1
之外的每个维度上进行约简,可以告诉我们1
维度中哪些索引包含NaN值。把这些放在一起就得到了上面的表达式。