Pytorch快速启动调用model.eval(),但不调用model.train()



在Pytorch快速启动教程中,代码在评估/测试期间使用model.eval(),但在训练期间不调用model.train()

根据这一点和来源,BatchNormDropout等模块需要知道模型是处于训练模式还是评估模式。教程中的模型不使用任何这样的模块,因此它运行到收敛。是我遗漏了什么,还是Pytorch的第一个教程实际上有一个逻辑错误?

培训:

def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)

# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)

# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

您可以看到在上面的代码中没有model.train()

测试:

def test(dataloader, model):
size = len(dataloader.dataset)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= size
correct /= size
print(f"Test Error: n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} n")

在第二行,有一个model.eval()

训练循环:

epochs = 5
for t in range(epochs):
print(f"Epoch {t+1}n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model)
print("Done!")

此循环调用train()test()方法,而不调用任何model.train()。因此在test()的第一次调用之后;评价";模式如果我们在模型中添加一个BatchNorm,我们就会遇到一个很难找到的错误。

主要问题:

在训练过程中总是呼叫model.train(),在评估/测试过程中呼叫model.eval(),这是一种好的做法吗?

正如教程的描述所说,QuickStart旨在"以快速熟悉PyTorch的API";而不是理解所有的概念。

我认为作者希望尽可能缩短快速启动时间。如果你真的想学习PyTorch,你可能会做完整的教程。那么你的问题最晚将在";优化模型参数";(官方教程的一部分(

def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
# Set the model to training mode - important for batch normalization and dropout layers
# Unnecessary in this situation but added for best practices
model.train()
...

但你是对的,一个简短的提示("你稍后会了解更多"(将是很好的

相关内容

最新更新