我试图将模型的状态dict临时存储在一个变量中,并希望稍后将其恢复到我的模型中,但随着模型的更新,该变量的内容会自动更改。
有一个最小的例子:
import torch as t
import torch.nn as nn
from torch.optim import Adam
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(3, 2)
def forward(self, x):
return self.fc(x)
net = Net()
loss_fc = nn.MSELoss()
optimizer = Adam(net.parameters())
weights = net.state_dict()
print(weights)
x = t.rand((5, 3))
y = t.rand((5, 2))
loss = loss_fc(net(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(weights)
我以为这两个输出会是一样的,但我得到了(输出可能会因为随机初始化而改变(
OrderedDict([('fc.weight', tensor([[-0.5557, 0.0544, -0.2277],
[-0.0793, 0.4334, -0.1548]])), ('fc.bias', tensor([-0.2204, 0.2846]))])
OrderedDict([('fc.weight', tensor([[-0.5547, 0.0554, -0.2267],
[-0.0783, 0.4344, -0.1538]])), ('fc.bias', tensor([-0.2194, 0.2856]))])
weights
的内容发生了变化,这太奇怪了。
我也尝试了.copy()
和t.no_grad()
,但它们没有帮助。
with t.no_grad():
weights = net.state_dict().copy()
是的,我知道我可以使用t.save()
保存状态dict,但我只想弄清楚上一个例子中发生了什么。
我正在使用Python 3.8.5
和Pytorch 1.8.1
谢谢你的帮助。
OrderedDict
就是这样工作的。这里有一个更简单的例子:
from collections import OrderedDict
# a mutable variable
l = [1,2,3]
# an OrderedDict with an entry pointing to that mutable variable
x = OrderedDict([("a", l)])
# if you change the list
l[1] = 20
# the change is reflected in the OrderedDict
print(x)
# >> OrderedDict([('a', [1, 20, 3])])
如果你想避免这种情况,你必须进行deepcopy
而不是浅copy
:
from copy import deepcopy
x2 = deepcopy(x)
print(x2)
# >> OrderedDict([('a', [1, 20, 3])])
# now, if you change the list
l[2] = 30
# you do not change your copy
print(x2)
# >> OrderedDict([('a', [1, 20, 3])])
# but you keep changing the original dict
print(x)
# >> OrderedDict([('a', [1, 20, 30])])
由于Tensor
也是可变的,因此在您的情况下也会出现相同的行为。因此,您可以使用:
from copy import deepcopy
weights = deepcopy(net.state_dict())