PyTorch custom module using an existing CNN model



I want to access and edit the individual modules of a torchvision model, and adjust its input.

I know you can access a submodule like this:

import torchvision
resnet18 = torchvision.models.resnet18()
print(resnet18._modules['conv1'])
# Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

But I want to create a custom Net(nn.Module) class so that I can add extra layers later:

class Sonar(resnet18):
    pass

This raises:

----> 1 class Sonar(resnet18):
2     pass
/usr/local/lib/python3.7/dist-packages/torchvision/models/resnet.py in __init__(self, block, layers, num_classes, zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer)
142         self.relu = nn.ReLU(inplace=True)
143         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
--> 144         self.layer1 = self._make_layer(block, 64, layers[0])
145         self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
146                                        dilate=replace_stride_with_dilation[0])
/usr/local/lib/python3.7/dist-packages/torchvision/models/resnet.py in _make_layer(self, block, planes, blocks, stride, dilate)
176             self.dilation *= stride
177             stride = 1
--> 178         if stride != 1 or self.inplanes != planes * block.expansion:
179             downsample = nn.Sequential(
180                 conv1x1(self.inplanes, planes * block.expansion, stride),
AttributeError: 'str' object has no attribute 'expansion'
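The traceback is easier to read once you know what a class statement does with an instance as its base: Python uses type(instance) as the metaclass and calls it with the new class's name, bases tuple and namespace dict. The plain-Python sketch below reproduces this with a hypothetical FakeResNet stand-in (not torchvision code) whose __init__ starts with the same block and layers parameters as ResNet.__init__:

```python
class Block:
    expansion = 1  # ResNet block classes expose an `expansion` attribute


calls = []  # records the arguments FakeResNet.__init__ receives


class FakeResNet:
    """Hypothetical stand-in whose signature starts like ResNet.__init__."""

    def __init__(self, block, layers, num_classes=1000):
        calls.append((block, layers, num_classes))
        self.width = 64 * block.expansion  # ResNet also reads block.expansion


resnet18_like = FakeResNet(Block, [2, 2, 2, 2])  # an *instance*, like resnet18

try:
    # The metaclass becomes type(resnet18_like), i.e. FakeResNet, so Python
    # calls FakeResNet("Sonar", (resnet18_like,), namespace_dict).
    class Sonar(resnet18_like):
        pass
except AttributeError as err:
    message = str(err)

print(calls[-1][0])  # Sonar
print(message)       # 'str' object has no attribute 'expansion'
```

The class name "Sonar" lands in the block parameter, which is exactly the str that the real ResNet._make_layer fails on.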

Trying again with AlexNet:

alexnet = torchvision.models.AlexNet()
class Sonar(alexnet):
    pass

This raises:

1 alexnet = torchvision.models.AlexNet()
----> 2 class Sonar(alexnet):
3     pass
TypeError: __init__() takes from 1 to 2 positional arguments but 4 were given

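Both errors have the same cause: resnet18 and alexnet are model instances, not classes. When the base in a class statement is an instance, Python uses type(base) as the metaclass and calls it with the new class's name, bases and namespace, i.e. ResNet('Sonar', (resnet18,), {...}). ResNet.__init__ therefore receives block='Sonar' (hence "'str' object has no attribute 'expansion'"), and AlexNet.__init__ receives more positional arguments than it accepts. Instead of inheriting from an instance, wrap the existing model as a submodule of your own nn.Module. The following should work: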

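import torch
import torchvision

class Sonar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Wrap the existing model as a submodule instead of inheriting from it.
        # (pretrained=True is deprecated in torchvision >= 0.13; use weights="IMAGENET1K_V1" there.)
        self.ins = torchvision.models.resnet18(pretrained=True)
        self.fc1 = torch.nn.Linear(1000, 1)  # extra layer on top of the 1000 ImageNet logits

    def forward(self, x):
        out = self.ins(x)
        out = self.fc1(out)
        return out

def run():
    return Sonar()

net = run()
print(net(torch.ones(1, 3, 224, 224)))  # quick test with a dummy batch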

Could you import the Python file and edit it locally instead?
