将Pytorch .pth模型的权重保存为.txt或.json



我试图将pytorch模型的权重保存为.txt或.json。当写入。txt文件时,

#import torch
model = torch.load("model_path")
string = str(model)
with open('some_file.txt', 'w') as fp:
fp.write(string)

我得到一个文件,其中不是所有的权重被保存,即有省略号在整个文本文件。我不能把它写到JSON,因为模型有张量,这不是JSON序列化[除非有一种方式,我不知道?如何将.pth文件中的权重保存为某种格式,使信息不会丢失,并且可以很容易地看到?

感谢

有点晚了,但希望这能有所帮助。这是你存储它的方式:

import torch
from torch.utils.data import Dataset
from json import JSONEncoder
import json
class EncodeTensor(JSONEncoder,Dataset):
def default(self, obj):
if isinstance(obj, torch.Tensor):
return obj.cpu().detach().numpy().tolist()
return super(NpEncoder, self).default(obj)
with open('torch_weights.json', 'w') as json_file:
json.dump(model.state_dict(), json_file,cls=EncodeTensor)

考虑到存储的值是list类型,所以当你要使用权重时,你必须使用torch.Tensor(list)

当你做str(model.state_dict())时,它递归地使用str方法的元素包含。所以问题是如何构建单个元素字符串表示。您应该增加以单个字符串表示方式打印的行数限制:

torch.set_printoptions(profile="full")

看下面的区别:

import torch
import torchvision.models as models
mobilenet_v2 = models.mobilenet_v2()
torch.set_printoptions(profile="default")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])
torch.set_printoptions(profile="full")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])

张量目前不支持JSON序列化。

最新更新