Pytorch:CNN在 torch.cat()之后什么也学不到?



我尝试使用像这样的代码在网络中串联变量

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = x.view(x.size(0), -1)
    x= torch.cat((x,angle),1) # from here I concat it.
    x = self.dropout1(self.relu1(self.bn1(self.fc1(x))))
    x = self.dropout2(self.relu2(self.bn2(self.fc2(x))))
    x = self.fc3(x)

,然后我发现我的网络什么也没学,并始终给ACC约50%。因此,我打印了param.grad,正如我所期望的,它们都是Nan。有人以前遇到过这个东西吗?

我以前没有串联运行代码,并且效果很好。因此,我想这是摩擦所在的地方,并且系统不会抛出任何错误或异常。如果需要其他备份信息,请告诉我。

谢谢。

可能该错误在您提供的代码之外。尝试检查输入中是否有NAN,并检查损失功能是否未导致NAN。

最新更新