What is the difference between "model_zoo.load_url" and "state_dict"?


from torch.utils import model_zoo

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

# `args` and the custom `resnet18` constructor are defined elsewhere in my project.
if args.dataset == 'cifar100' or args.dataset == 'cifar10':
    args.stride = [2, 2]
resnet = resnet18(args, pretrained=False, num_classes=args.num_classes)
# Pretrained ImageNet weights, downloaded as a dict of name -> tensor.
initial_weight = model_zoo.load_url(model_urls['resnet18'])
local_model = resnet
# Weights of the freshly initialized local model.
initial_weight_1 = local_model.state_dict()
for key in initial_weight.keys():
    if key[0:3] == 'fc.' or key[0:5] == 'conv1' or key[0:3] == 'bn1':
        initial_weight[key] = initial_weight_1[key]
local_model.load_state_dict(initial_weight)

I don't understand this line: initial_weight[key] = initial_weight_1[key]

Can you tell me why we need to do this?

Thanks

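The function torch.utils.model_zoo.load_url loads a serialized torch object from the given URL. In this particular case, the URL hosts the weight dictionary of a pretrained ResNet18 network.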

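So initial_weight is a dictionary holding the weights of the pretrained ResNet18, while initial_weight_1 is the dictionary of weights of the model currently in memory, resnet, as initialized by resnet18.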
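To make the "dictionary of weights" idea concrete, here is a minimal sketch of what state_dict returns and how load_state_dict consumes such a dictionary. The two-layer model below is a hypothetical stand-in, not the ResNet from the question, and nothing is downloaded.

```python
import torch
from torch import nn

# Hypothetical tiny model, used only to show what a state_dict looks like.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# state_dict() maps each parameter name to its tensor.
state = model.state_dict()
param_names = list(state.keys())
shapes = {name: tuple(t.shape) for name, t in state.items()}

# load_state_dict() copies values from a dict with matching keys and shapes
# back into the model. Here every parameter is replaced by zeros.
zeroed = {name: torch.zeros_like(t) for name, t in state.items()}
model.load_state_dict(zeroed)
all_zero = all(bool((p == 0).all()) for p in model.parameters())
```

The dictionary returned by load_url has the same shape of data (names mapped to tensors), which is why it can be edited key by key and then passed to load_state_dict.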

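The loop then walks over the keys of the pretrained dictionary. Whenever the condition key[0:3] == 'fc.' or key[0:5]=='conv1' or key[0:3]=='bn1' holds, the statement initial_weight[key] = initial_weight_1[key] overwrites the pretrained entry with the local model's own value. After the final load_state_dict call, the top-level conv1, bn1 and fc layers therefore keep their fresh initialization, and every other layer receives the pretrained weights. Presumably this is done because those layers were changed for CIFAR (custom stride, a different num_classes), so the pretrained tensors for them may not fit or may not be wanted.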
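The effect of the loop can be sketched with plain dictionaries, which makes it easy to see which entries end up coming from which source. The helper name merge_pretrained and the string placeholders are made up for illustration; real state dicts hold tensors, and this sketch builds a new dict rather than editing the loaded one in place as the question's code does.

```python
def merge_pretrained(pretrained, local, skip_prefixes=("fc.", "conv1", "bn1")):
    """Return the pretrained entries, except that keys starting with one of
    skip_prefixes take their value from the local model instead."""
    merged = dict(pretrained)
    for key in pretrained:
        # Same test as key[0:3] == 'fc.' or key[0:5] == 'conv1' or key[0:3] == 'bn1'.
        if key.startswith(skip_prefixes):
            merged[key] = local[key]
    return merged


# Placeholder values stand in for the weight tensors.
keys = ["conv1.weight", "bn1.weight", "layer1.0.conv1.weight",
        "layer1.0.bn1.weight", "fc.weight", "fc.bias"]
pretrained = {k: "pretrained" for k in keys}
local = {k: "local" for k in keys}

merged = merge_pretrained(pretrained, local)
```

Note that the prefix test only matches the top-level layers: a key such as layer1.0.conv1.weight starts with "layer1", so it keeps its pretrained value.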
