How do I fix a NotImplementedError when training a Unet model?


import numpy as np
import torch
from tqdm import tqdm

# DEVICE, LR, EPOCHS, model, trainloader and validloader are defined elsewhere.

def train_fn(data_loader, model, optimizer):
    model.train()
    total_loss = 0.0
    for images, masks in tqdm(data_loader):
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)
        optimizer.zero_grad()
        # When masks are passed, the model returns (logits, loss).
        logits, loss = model(images, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(data_loader)

def eval_fn(data_loader, model):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():  # no gradients are needed for validation
        for images, masks in tqdm(data_loader):
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)
            logits, loss = model(images, masks)

            total_loss += loss.item()

    return total_loss / len(data_loader)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
best_valid_loss = np.inf
for i in range(EPOCHS):

    train_loss = train_fn(trainloader, model, optimizer)
    valid_loss = eval_fn(validloader, model)
    # Keep only the checkpoint with the lowest validation loss so far.
    if valid_loss < best_valid_loss:
        torch.save(model.state_dict(), 'best_model.pt')
        print("SAVED_MODEL")
        best_valid_loss = valid_loss

    print(f"Epoch: {i+1} Train_loss: {train_loss} Valid_loss: {valid_loss}")

When I try to train the model, I get the following error:

0%| | 0/15 [00:00

NotImplementedError                       Traceback (most recent call last)
in <module>()
      4
      5
----> 6     train_loss = train_fn(trainloader, model, optimizer)
      7     valid_loss = eval_fn(validloader, model)
      8

2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
    199         registered hooks while the latter silently ignores them.
    200     """
--> 201     raise NotImplementedError
    202
    203

NotImplementedError:

How can I fix this?

Looking at the link you provided in the comments, your model is defined like this:

class SegmentationModel(nn.Module):
    def __init__(self):
        super(SegmentationModel,self).__init__()
        self.arc = smp.Unet(
            encoder_name = ENCODER,
            encoder_weights = WEIGHTS,
            in_channels = 3,
            classes = 1,
            activation = None
        )
        def forward(self, images, masks = None):
            logits = self.arc(images)
            if masks != None:
                loss1 = DiceLoss(mode = 'binary')(logits, masks)
                loss2 = nn.BCEWithLogitsLoss()(logits,masks)
                return logits, loss1 + loss2
            return logits

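If you look closely, you'll see that forward() has a stray extra level of indentation, which makes it a local function inside __init__() rather than a method of SegmentationModel. Because the class then never overrides forward(), calling model(images, masks) falls through to nn.Module's placeholder _forward_unimplemented, which raises exactly the NotImplementedError in your traceback. Move forward() one level to the left, to the same level as __init__(), and it should work: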

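import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss

# ENCODER and WEIGHTS are defined elsewhere in your code.

class SegmentationModel(nn.Module):
    def __init__(self):
        super(SegmentationModel, self).__init__()
        self.arc = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            activation=None,  # no activation: the network outputs raw logits
        )

    # Same indentation as __init__, so this overrides nn.Module.forward.
    def forward(self, images, masks=None):
        logits = self.arc(images)
        if masks is not None:
            # Training/validation: sum the Dice and BCE losses on the logits.
            loss1 = DiceLoss(mode='binary')(logits, masks)
            loss2 = nn.BCEWithLogitsLoss()(logits, masks)
            return logits, loss1 + loss2
        # Inference: no masks given, so return only the logits.
        return logits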
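To see why the indentation alone causes the error, here is a minimal sketch that needs only the standard library. ModuleStandIn, BrokenModel, FixedModel and defines_forward are hypothetical names made up for this illustration; ModuleStandIn only roughly imitates how nn.Module routes model(...) to self.forward(...) and falls back to a default that raises NotImplementedError. It is not PyTorch's actual implementation.

```python
class ModuleStandIn:
    """Hypothetical stand-in for torch.nn.Module (call dispatch only)."""

    def forward(self, *args, **kwargs):
        # Like nn.Module's default forward (_forward_unimplemented).
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        # nn.Module.__call__ likewise ends up calling self.forward(...).
        return self.forward(*args, **kwargs)


class BrokenModel(ModuleStandIn):
    def __init__(self):
        super().__init__()

        # Indented one level too far: a local function of __init__ that is
        # discarded when __init__ returns, not a method of the class.
        def forward(self, x):
            return x * 2


class FixedModel(ModuleStandIn):
    def __init__(self):
        super().__init__()

    # Same level as __init__: a real method overriding the default forward.
    def forward(self, x):
        return x * 2


def defines_forward(cls):
    """Return True if cls itself (not a base class) defines forward()."""
    return "forward" in vars(cls)


if __name__ == "__main__":
    print(defines_forward(BrokenModel))  # False
    print(defines_forward(FixedModel))   # True
    print(FixedModel()(21))              # 42
    try:
        BrokenModel()(21)
    except NotImplementedError:
        print("BrokenModel raised NotImplementedError")
```

The same check works on the real class, since vars() is plain Python: 'forward' in vars(SegmentationModel) is False for the nested version and True once forward() is moved out of __init__().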
