How to rename a PyTorch object?



My PyTorch model:

```
EfficientDet(
  (backbone): EfficientNetFeatures(
    (conv_stem): Conv2d(4, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    ...
...
```

Is there a way to rename the backbone object to a different name?

You can rename an attribute of an instance (here, a registered child module) with the following function:

```python
import torch.nn as nn


def rename_attribute(obj, old_name, new_name):
    # nn.Module stores child modules in the private _modules dict, keyed by
    # attribute name; moving the entry to a new key renames the child.
    obj._modules[new_name] = obj._modules.pop(old_name)


class EfficientNetFeatures(nn.Module):
    def __init__(self):
        super(EfficientNetFeatures, self).__init__()
        self.conv_stem = nn.Conv2d(4, 48, kernel_size=(3, 3),
                                   stride=(2, 2), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(48, eps=0.001, momentum=0.1,
                                  affine=True, track_running_stats=True)


class EfficientDet(nn.Module):
    def __init__(self):
        super(EfficientDet, self).__init__()
        self.backbone = EfficientNetFeatures()


model = EfficientDet()
print(model)
rename_attribute(model, 'backbone', 'newname')
print(model)
```

Output:

```
EfficientDet(
  (backbone): EfficientNetFeatures(
    (conv_stem): Conv2d(4, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
)
EfficientDet(
  (newname): EfficientNetFeatures(
    (conv_stem): Conv2d(4, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
)
```
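One caveat with rename_attribute: pop followed by reinsertion moves the renamed child to the end of _modules, so in a model with several children it prints and iterates last. The following is a minimal sketch of an order-preserving variant that also accepts a dotted path to reach a nested child. The function name rename_submodule and the extra head child are hypothetical, added only for illustration; like the original, it relies on the private _modules registry, which is not a public PyTorch API.

```python
import torch.nn as nn


def rename_submodule(root: nn.Module, old_path: str, new_name: str) -> None:
    """Rename a registered child module in place without changing its position.

    Hypothetical helper. old_path may be dotted (e.g. "backbone.bn1") to reach
    a nested child; new_name is the new attribute name under that child's parent.
    Relies on the private _modules registry, same as rename_attribute.
    """
    *parent_names, old_name = old_path.split(".")
    parent = root
    for name in parent_names:  # walk down to the direct parent
        parent = getattr(parent, name)
    if old_name not in parent._modules:
        raise KeyError(f"{old_path!r} is not a registered submodule")
    if new_name == old_name:
        return
    if hasattr(parent, new_name):  # avoid clobbering a child, parameter or method
        raise KeyError(f"{new_name!r} already exists on {type(parent).__name__}")
    # Rebuild the registry in place so the renamed entry keeps its original slot.
    items = list(parent._modules.items())
    parent._modules.clear()
    for key, module in items:
        parent._modules[new_name if key == old_name else key] = module


class EfficientNetFeatures(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_stem = nn.Conv2d(4, 48, kernel_size=(3, 3),
                                   stride=(2, 2), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(48, eps=0.001, momentum=0.1,
                                  affine=True, track_running_stats=True)


class EfficientDet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = EfficientNetFeatures()
        self.head = nn.Identity()  # hypothetical second child, to show ordering


model = EfficientDet()
rename_submodule(model, "backbone", "newname")  # stays before head
rename_submodule(model, "newname.bn1", "norm")  # nested rename via dotted path
print(model)
print(list(model.state_dict()))
```

Keep in mind that renaming also changes the state_dict keys (backbone.conv_stem.weight becomes newname.conv_stem.weight), so a checkpoint saved under the old name needs its keys remapped before load_state_dict, and any forward method that refers to self.backbone must be updated to the new name.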
