权重归一化导致PyTorch中的nan



我使用PyTorch 1.2.0内置的权重归一化。当使用权范数的层的权值接近于0时,权范数运算得到NaN,然后在整个网络中传播。为了解决这个问题,我想在PyTorch权重规范函数中为weight_v的规范添加一个像eps = 1e-6这样的小值。

所以我试图在我的本地计算机上找到这个函数,并在miniconda3/envs/pytorch1_2/lib/python3.7/site-packages/torch/nn/utils/weight_norm.py(GitHub代码)找到它,并试图修改它。

我想找到哪个函数在计算权重范数,所以我在每个函数中添加了hi,并发现compute_weight函数在每次向前传递之前被调用。

这个函数调用_weight_norm函数存储在miniconda3/envs/pytorch1_2/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py(GitHub代码)。当我将print("hi")添加到_weight_norm函数中,但是它没有被打印出来。

那么,通过将eps添加到权重规范中来修改权重规范代码的正确方法是什么?也许可以用最新的PyTorch 1.9.0版本替换我本地计算机上的_weight_norm函数,但不确定如何添加eps

在https://github.com/facebookresearch/multiface/blob/main/models.py#L580找到临时解决方案

所以不要使用常规的nn.conv2d,使用

class Conv2dWN(nn.Conv2d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
):
super(Conv2dWN, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
True,
)
self.g = nn.Parameter(torch.ones(out_channels))
def forward(self, x):
wnorm = torch.sqrt(torch.sum(self.weight**2))
return F.conv2d(
x,
self.weight * self.g[:, None, None, None] / wnorm,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)

最新更新