Getting the encoder from a trained UNet



I have trained a UNet model on some images, and now I want to extract the encoder part of the model. My UNet has the following architecture:

UNet(
  (conv_final): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))
  (down_convs): ModuleList(
    (0): DownConv(
      (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): DownConv(
      (conv1): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): DownConv(
      (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): DownConv(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (4): DownConv(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (up_convs): ModuleList(
    (0): UpConv(
      (upconv): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
      (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): UpConv(
      (upconv): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
      (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): UpConv(
      (upconv): ConvTranspose2d(32, 16, kernel_size=(2, 2), stride=(2, 2))
      (conv1): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): UpConv(
      (upconv): ConvTranspose2d(16, 8, kernel_size=(2, 2), stride=(2, 2))
      (conv1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
)
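For reference, the encoder half of the printout above can be sketched as follows. This is a hypothetical re-creation, not the poster's code: it assumes each `DownConv` applies ReLU after both convolutions and that its `forward` returns a `(pooled, before_pool)` pair, a convention used by some popular PyTorch UNet implementations. The channel list is read off the printed layers, and the last block has no pooling, matching `(4): DownConv` above. The sketch traces a 64x64 single-channel input through the five blocks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownConv(nn.Module):
    """Hypothetical re-creation of the printed DownConv block.

    Assumes ReLU after each conv and that forward() returns
    (pooled_output, pre_pool_features); the poster's class may differ.
    """

    def __init__(self, in_channels, out_channels, pooling=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.pooling = pooling
        if pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        before_pool = x  # kept as the skip connection for the decoder
        if self.pooling:
            x = self.pool(x)
        return x, before_pool


# Channel progression read off the printed model: 1 -> 8 -> 16 -> 32 -> 64 -> 128.
# Only the last (deepest) block skips pooling.
channels = [1, 8, 16, 32, 64, 128]
down_convs = nn.ModuleList([
    DownConv(channels[i], channels[i + 1], pooling=(i < len(channels) - 2))
    for i in range(len(channels) - 1)
])

x = torch.randn(1, 1, 64, 64)
skip_shapes = []
for block in down_convs:
    x, before_pool = block(x)
    skip_shapes.append(tuple(before_pool.shape))

print(skip_shapes)
print(tuple(x.shape))
```

Under these assumptions the five skip features have 8, 16, 32, 64 and 128 channels at 64, 32, 16, 8 and 4 pixels per side, and the deepest block's output (the bottleneck) is 128 x 4 x 4.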

I tried to get the encoder layers via model.down_convs, but I got the following error:

TypeError                                 Traceback (most recent call last)
----> 1 res = encoder(train_img)

~/anconda3/envs/work/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

TypeError: forward() takes 1 positional argument but 2 were given
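One thing worth checking here, independently of where exactly that TypeError is raised: `model.down_convs` is an `nn.ModuleList`, which is only a container for submodules and does not define a forward pass of its own, so it cannot be applied to a tensor directly. The small sketch below uses a stand-in `ModuleList` of ordinary layers (not the poster's `DownConv` blocks) to show that calling the container fails with `NotImplementedError`, while calling each submodule in turn works.

```python
import torch
import torch.nn as nn

# Stand-in ModuleList built from plain layers (not the poster's DownConv blocks).
layers = nn.ModuleList([
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
])

x = torch.randn(1, 1, 64, 64)

# ModuleList defines no forward(), so calling it on a tensor raises.
try:
    layers(x)
    called_ok = True
except NotImplementedError:
    called_ok = False

# Iterating over the container and calling each submodule does work.
out = x
for layer in layers:
    out = layer(out)

print(called_ok, tuple(out.shape))
```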

I have attached the model and its weights so you can try it out.

Please let me know.

This should work:

import torch

net = UNet(8)  # UNet is your own model class; use the same constructor arguments as in training
net.load_state_dict(torch.load('PATH'))
print(net)  # check the names of the encoder layers
net1 = net.down_convs  # the encoder blocks are stored under the name down_convs
# net1 is your encoder (an nn.ModuleList of DownConv blocks)
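Note that `net.down_convs` is still an `nn.ModuleList`, so calling `net1(train_img)` directly will not run the blocks. One option, sketched below, is a small hypothetical `Encoder` wrapper whose `forward` loops over the blocks. It assumes a block may return either a tensor or an `(output, skip)` tuple, as `DownConv` blocks in many UNet implementations do; check what your own `DownConv.forward` returns. A stand-in `ModuleList` replaces `net.down_convs` so the example runs on its own.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Hypothetical wrapper that makes an nn.ModuleList of blocks callable.

    Runs each block in order. If a block returns a tuple (e.g. (pooled,
    before_pool)), the first element is passed on and the second is kept
    as a skip feature; otherwise the block's output doubles as its skip.
    """

    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks  # references (does not copy) the given modules

    def forward(self, x):
        skips = []
        for block in self.blocks:
            out = block(x)
            if isinstance(out, tuple):
                x, skip = out[0], out[1]
            else:
                x, skip = out, out
            skips.append(skip)
        return x, skips


# Stand-in for net.down_convs so the sketch is self-contained.
stand_in = nn.ModuleList([
    nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)),
    nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.ReLU()),
])

encoder = Encoder(stand_in)  # with the real model: Encoder(net.down_convs)
encoder.eval()  # inference mode
for p in encoder.parameters():
    p.requires_grad_(False)  # freeze the trained weights

with torch.no_grad():
    features, skips = encoder(torch.randn(2, 1, 32, 32))

print(tuple(features.shape), len(skips))
```

Because `Encoder(net.down_convs)` holds references to the same submodules, the trained weights are shared with `net` rather than copied; pass `copy.deepcopy(net.down_convs)` instead if you need an independent encoder, for example to fine-tune it separately.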
