如何在pytorch中处理运行时错误时对矩阵求逆进行矢量化



我需要在pytorch中反转一些矩阵。然而,有些矩阵是不可逆的,这导致代码抛出运行时错误如下,

matrices = torch.randn([5,3,3])
matrices[[2,3]] = torch.zeros([3,3])
inverses = torch.inverse(matrices)
RuntimeError: inverse_cpu: For batch 2: U(1,1) is zero, singular U.

我有一种应对这种情况的后备技巧。然而,我不知道哪个矩阵会产生错误。目前,我已经用非矢量化版本替换了代码,但它已经成为一个瓶颈。

有没有一种方法可以在不放弃矢量化的情况下处理这个问题?

我能想到的最好的方法是首先计算每个矩阵的行列式,然后计算那些具有abs(det)>0的矩阵的逆。

matrices = torch.randn([5,3,3])
matrices[[2,3]] = torch.zeros([3,3])
determinants = torch.det(matrices)
inverses = torch.inverse(matrices[determinants.abs()>0.])

您必须处理奇异矩阵的移除,但这应该不会太难,因为您已经从determinants.abs()==0.中获得了这些矩阵的索引值。这样可以使反转保持矢量化。

相关内容

  • 没有找到相关文章