我正在处理Conv2d(3,3,kernel_size=5, stride=1),我需要设置一些特定的权重为零,并使它们不可更新。例如,如果我输入
model = nn.Conv2d(3,3,kernel_size=5, stride=1)
model.weight.requires_grad = False
一切工作,但它影响整个层。我想这样做:
model = nn.Conv2d(3,3,kernel_size=5, stride=1)
model.weight[0,2].requires_grad = False # this line does not work
model.weight[0,2] = 0 # this line does not work either
它似乎不支持层参数子组的赋值和requires_grad操作。有人已经解决了这个问题吗?
您可以通过访问data
属性
>>> model.weight.data[0, 2] = 0
或使用torch.no_grad
上下文管理器:
>>> with torch.no_grad():
... model.weight[0, 2] = 0
正如您注意到的,您不能为子模块专门设置requires_grad
。因此,给定模块的所有参数元素共享相同的标志,它们要么被更新,要么不被更新。
另一种方法是在调用后,在优化器步骤之前,手动终止该通道的梯度:
>>> model(torch.rand(2, 3, 100, 100)).mean().backward()
>>> model.weight.grad[0, 2] = 0
>>> optim.step()
这样,过滤器n°1上的第三个通道将不会被向后传递更新,并保持在0
。