我想访问并修改 torchvision 模型中的各个子模块（层），并调整模型的输入。
我知道可以像下面这样访问（进而修改）子模块：
import torchvision

resnet18 = torchvision.models.resnet18()
# Child layers are reachable as ordinary attributes of the model; this is the
# public way to reach what the private ``_modules`` dict stores.
print(resnet18.conv1)
# Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
但我想创建一个自定义的 Net(nn.Module) 类，这样以后就可以在它上面添加额外的层：
# NOTE: this fails. `resnet18` here is the model *instance* created above, not
# a class. When a base is an instance, Python creates the new class by calling
# type(base)(name, bases, namespace), i.e. the model's class is called as
# ResNet("Sonar", (resnet18,), {...}). Its __init__ therefore receives the
# string "Sonar" as `block`, which raises
# AttributeError: 'str' object has no attribute 'expansion'.
# Subclass the model's class (type(resnet18)) instead, or wrap the instance as
# a submodule of a torch.nn.Module.
class Sonar(resnet18):
    pass
抛出错误:
----> 1 class Sonar(resnet18):
2 pass
/usr/local/lib/python3.7/dist-packages/torchvision/models/resnet.py in __init__(self, block, layers, num_classes, zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer)
142 self.relu = nn.ReLU(inplace=True)
143 self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
--> 144 self.layer1 = self._make_layer(block, 64, layers[0])
145 self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
146 dilate=replace_stride_with_dilation[0])
/usr/local/lib/python3.7/dist-packages/torchvision/models/resnet.py in _make_layer(self, block, planes, blocks, stride, dilate)
176 self.dilation *= stride
177 stride = 1
--> 178 if stride != 1 or self.inplanes != planes * block.expansion:
179 downsample = nn.Sequential(
180 conv1x1(self.inplanes, planes * block.expansion, stride),
AttributeError: 'str' object has no attribute 'expansion'
改用 AlexNet 重试：
alexnet = torchvision.models.AlexNet()
# NOTE: this fails for the same reason: `alexnet` is an instance, so Python
# tries to create the class by calling AlexNet("Sonar", bases, namespace).
# AlexNet's __init__ accepts at most one argument besides self, hence
# TypeError: __init__() takes from 1 to 2 positional arguments but 4 were given.
# Subclass torchvision.models.AlexNet (the class itself) instead.
class Sonar(alexnet):
    pass
抛出错误:
1 alexnet = torchvision.models.AlexNet()
----> 2 class Sonar(alexnet):
3 pass
TypeError: __init__() takes from 1 to 2 positional arguments but 4 were given
下面这种写法可以正常工作——不去继承模型实例，而是把预训练的 resnet18 作为子模块包装在自定义的 nn.Module 中：
import torch
import torchvision
class Sonar(torch.nn.Module):
    """ResNet-18 backbone with an extra fully connected layer stacked on top.

    The torchvision model is held as the submodule ``ins`` (composition
    rather than subclassing an instance), so more layers can be added freely.

    Args:
        num_outputs: Output width of the added linear layer. Defaults to 1,
            matching the original single-output head.
        pretrained: Forwarded to ``torchvision.models.resnet18``; when True,
            pretrained weights are requested for the backbone.
    """

    def __init__(self, *, num_outputs=1, pretrained=True):
        super().__init__()
        self.ins = torchvision.models.resnet18(pretrained=pretrained)
        # Size the added layer from the backbone's own classifier output
        # instead of a hard-coded 1000, so the two cannot drift apart.
        self.fc1 = torch.nn.Linear(self.ins.fc.out_features, num_outputs)

    def forward(self, x):
        """Run ``x`` through the backbone, then through the added layer."""
        out = self.ins(x)
        out = self.fc1(out)
        return out
def run():
    """Construct a fresh ``Sonar`` model with its default settings and return it."""
    model = Sonar()
    return model
net = run()
# Forward-only smoke test: eval() switches any BatchNorm/Dropout layers to
# inference behavior (so this check does not update running statistics), and
# no_grad() avoids building an autograd graph that is never used.
net.eval()
with torch.no_grad():
    print(net(torch.ones(1, 3, 224, 224)))  # expected output shape: (1, 1)
另外，能否把（torchvision 中对应的）Python 源文件导入到本地，并直接在本地修改它？