How to save/load a model checkpoint with several losses in PyTorch?



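Using Ubuntu 20.04, PyTorch 1.10.1.

I am trying to solve a music generation task with a transformer architecture and multiple embeddings, in order to process tokens that have several features.

On each training iteration I have to compute the loss for each token feature and store it in a vector. I assume I should then store a vector containing all of these per-feature losses (or something similar) in the checkpoint, instead of saving only the total loss as I do now. I would like to know how to store all the losses in the checkpoint (so that training can resume after loading), or whether that is not needed at all.

Epoch loop: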

for epoch in range(0, epochs):

    print('Epoch: ', epoch)

    loss = trfrmr.train(epoch+1, model, train_loader, train_loss_func, opt, lr_scheduler, num_iters=-1)
    loss_train.append(loss)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss': loss,
    }, "model_pop909_checkpoint.pth")

Training loop:

for batch_num, batch in enumerate(dataloader):
    time_before = time.time()
    opt.zero_grad()
    x = batch[0].to(get_device())
    tgt = batch[1].to(get_device())
    # x is the input sequence (N,T,Z); the transformer forward function expects (T,N,Z)
    y = model(x.permute(1, 0, 2))
    # tgt is the real output sequence, of shape (N,T,Z): T is the sequence length, N the batch size, Z the number of token types
    # y holds the output logits: a list of Z tensors of shape (T,N,C), where the vocabulary size C varies with the token type (pitch, velocity etc.)
    losses = []
    for j in range(LEN_VOCAB):
        # shapes (N,C,T) and (N,T), see PyTorch cross-entropy for details
        aux_loss = loss(y[j].permute(1, 2, 0), tgt[..., j])
        losses.append(aux_loss)
    losses_sum = sum(losses)  # here we sum, but we could also take the mean, for instance
    losses_sum.backward()
    opt.step()
    if lr_scheduler is not None:
        lr_scheduler.step()
    lr = opt.param_groups[0]['lr']

    # store a plain float so the history does not keep tensors and their autograd references alive
    loss_hist.append(losses_sum.item())
    if batch_num == num_iters:
        break
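The shape handling in the loop above can be checked in isolation. The following is only a sketch with random data: the sizes `T`, `N` and the list `vocab_sizes` are made-up values, and `nn.CrossEntropyLoss` is assumed to be the loss in use, as the comment in the loop suggests.

```python
import torch
from torch import nn

torch.manual_seed(0)

T, N = 5, 2                # sequence length and batch size (made-up values)
vocab_sizes = [7, 3, 4]    # hypothetical vocabulary size for each token feature

# Fake logits: one (T, N, C_j) tensor per token feature, like the model output y.
y = [torch.randn(T, N, C, requires_grad=True) for C in vocab_sizes]
# Fake targets of shape (N, T, Z), one class index per feature.
tgt = torch.stack([torch.randint(0, C, (N, T)) for C in vocab_sizes], dim=-1)

loss_fn = nn.CrossEntropyLoss()

# One loss per feature: logits permuted to (N, C_j, T), targets of shape (N, T).
losses = [loss_fn(y[j].permute(1, 2, 0), tgt[..., j]) for j in range(len(vocab_sizes))]
losses_sum = sum(losses)
losses_sum.backward()      # a single backward pass covers every feature

# Plain floats are enough for logging the per-feature values.
per_feature = [l.item() for l in losses]
```

Each entry of `per_feature` is an ordinary Python float, so it can be appended to a history list without holding on to any tensor.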

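As far as I can tell, your loss function has no custom learnable parameters; it is recomputed on every iteration. So apart from keeping its history, there is no need to save its value; it is not required to resume training from a checkpoint.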
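If the per-feature history is still wanted for logging, one option is to store it in the checkpoint as plain floats next to the state dicts. This is a minimal sketch, not the code from the question: the tiny `nn.Linear` model, the three loss values and the `feature_loss_hist` key are all hypothetical, and an in-memory buffer stands in for the checkpoint file.

```python
import io
import torch
from torch import nn

# Hypothetical stand-ins for the real model and optimizer.
model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take one optimizer step so the optimizer has state worth saving.
model(torch.randn(8, 4)).sum().backward()
opt.step()

# Hypothetical per-feature losses for one iteration (e.g. pitch, velocity, duration).
losses = [torch.tensor(0.9), torch.tensor(0.4), torch.tensor(0.2)]
# Keep plain floats rather than tensors in the history.
feature_loss_hist = [[l.item() for l in losses]]

buffer = io.BytesIO()  # stands in for a path such as "model_pop909_checkpoint.pth"
torch.save({
    'epoch': 0,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': opt.state_dict(),
    'feature_loss_hist': feature_loss_hist,  # assumed key, used for logging only
}, buffer)

# Loading: the history comes back as ordinary nested lists of floats.
buffer.seek(0)
checkpoint = torch.load(buffer)
restored_hist = checkpoint['feature_loss_hist']
```

Nothing in `restored_hist` feeds back into training; only the model and optimizer state dicts affect how training continues.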

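The problem was that I was not reloading the checkpoint correctly: I was loading only the model parameters, not the optimizer parameters. Now, at the start of the loop in my code, I do this: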

if loaded:
    print('Loading model and optimizer...')
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    opt.load_state_dict(checkpoint['optimizer_state_dict'])
    print('Loaded successfully!')

I also load the epoch:

epoch = 0
if loaded:
    print('Loading epoch value...')
    epoch = checkpoint['epoch']
    print('Loaded successfully!')
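To tie the pieces together, here is a rough sketch of a full save-and-resume cycle. It rests on assumptions that are not in the code above: the `build` helper, the `StepLR` scheduler, the `lr_scheduler_state_dict` key and the `start_epoch` variable are hypothetical, and an in-memory buffer stands in for the checkpoint file. It assumes the saved `epoch` is the last completed one, so training resumes at the next epoch, and that the scheduler is stateful and therefore worth saving too.

```python
import io
import torch
from torch import nn

def build():
    # Hypothetical stand-ins for the real model, optimizer and scheduler.
    model = nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)
    return model, opt, lr_scheduler

epochs = 5
model, opt, lr_scheduler = build()

# First run: train for two epochs, checkpointing after each one.
for epoch in range(0, 2):
    opt.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    opt.step()
    lr_scheduler.step()
    buffer = io.BytesIO()  # stands in for the checkpoint file
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'lr_scheduler_state_dict': lr_scheduler.state_dict(),  # assumed extra key
    }, buffer)

# Second run: rebuild the objects first, then restore their state.
buffer.seek(0)
checkpoint = torch.load(buffer)
model, opt, lr_scheduler = build()
model.load_state_dict(checkpoint['model_state_dict'])
opt.load_state_dict(checkpoint['optimizer_state_dict'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])

# The saved epoch has already been completed, so continue from the next one.
start_epoch = checkpoint['epoch'] + 1
remaining_epochs = list(range(start_epoch, epochs))
resumed_lr = opt.param_groups[0]['lr']
```

In a real script the epoch loop would then run over `range(start_epoch, epochs)` rather than `range(0, epochs)`, so that already completed epochs are not repeated.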

This answer was posted as an edit to the question How to save/load a model checkpoint with several losses in Pytorch? by the OP Enrique Vilchez Campillejo under CC BY-SA 4.0.
