I am trying to implement the following function to save model_state checkpoints:
def train_epoch(self):
    for epoch in tqdm.trange(self.epoch, self.max_epoch, desc='Train Epoch', ncols=100):
        self.epoch = epoch  # increments the epoch of Trainer
        checkpoint = {}  # fixme: here checkpoint!!!
        # model_save_criteria = self.model_save_criteria
        self.train()
        if epoch % 1 == 0:
            self.validate(checkpoint)
        checkpoint_latest = {
            'epoch': self.epoch,
            'arch': self.model.__class__.__name__,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
            'model_save_criteria': self.model_save_criteria
        }
        checkpoint['checkpoint_latest'] = checkpoint_latest
        torch.save(checkpoint, self.model_pth)
Previously I was just running a plain for loop:
train_states = {}
for epoch in range(max_epochs):
    running_loss = 0
    time_batch_start = time.time()
    model.train()
    for bIdx, sample in enumerate(train_loader):
        ...
        # train ...
    # validation ...
    train_states_latest = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'model_save_criteria': chosen_criteria}
    train_states['train_states_latest'] = train_states_latest
    torch.save(train_states, FILEPATH_MODEL_SAVE)
Is there a way to initialize checkpoint = {} once and then update it on every iteration of the loop? Or is re-creating checkpoint = {} each epoch fine, since the model itself holds the state_dict()? The problem is just that I overwrite the checkpoint every time.
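To make the first option concrete, this is roughly the pattern I mean (a minimal sketch; the train/validation steps and chosen_criteria are placeholders for my real code):

checkpoint = {}  # initialized once, before the epoch loop
for epoch in range(max_epochs):
    # ... train and validate as before ...
    # state_dict() is fetched fresh from the live model each epoch,
    # so reusing the same outer dict just overwrites this one key.
    checkpoint['checkpoint_latest'] = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'model_save_criteria': chosen_criteria,
    }
    torch.save(checkpoint, FILEPATH_MODEL_SAVE)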
You can avoid overwriting checkpoints simply by changing the FILEPATH_MODEL_SAVE path so that it includes the epoch or iteration number. For example (using your original code):
train_states = {}
for epoch in range(max_epochs):
    running_loss = 0
    time_batch_start = time.time()
    model.train()
    for bIdx, sample in enumerate(train_loader):
        ...
        # train ...
    # validation ...
    train_states_latest = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'model_save_criteria': chosen_criteria}
    train_states['train_states_latest'] = train_states_latest
    # This is the code you can add: a filename that encodes the epoch and batch index
    FILEPATH_MODEL_SAVE = "Epoch{}batch{}model_weights.pth".format(epoch, bIdx)
    torch.save(train_states, FILEPATH_MODEL_SAVE)
With this new torch.save call, you avoid overwriting the checkpoint, since each epoch is saved to its own file.
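If you later want to resume from one of these files, you can load the dict back and restore both the model and the optimizer. A minimal sketch, assuming the same keys as above (the filename here is just one example of the names the loop generates):

train_states = torch.load("Epoch10batch99model_weights.pth")
latest = train_states['train_states_latest']
model.load_state_dict(latest['model_state_dict'])
optimizer.load_state_dict(latest['optimizer_state_dict'])
start_epoch = latest['epoch']  # resume training from this epoch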
Sarthak