How does iteration over torch.utils.data.DataLoader work?



Here is my code (the screenshot of the output is omitted):

import torch
import torch.nn as nn
import torch.optim as optim

# trainset and NEURAL_NETWORK are defined elsewhere in my code
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

model = NEURAL_NETWORK()
optimizer = optim.SGD(model.parameters(), lr=0.03)
costfx = nn.CrossEntropyLoss()

def train_model(model, batches):
    times = 0
    for i in range(batches):
        accuracy = 0
        for image, label in trainloader:
            optimizer.zero_grad()

            output = model(image)
            times += 1  # count iterations of the inner loop

            top_p, top_class = output.topk(1, dim=1)
            equals = top_class == label.view(*top_class.shape)
            accuracy += torch.mean(equals.type(torch.FloatTensor))

            loss = costfx(output, label)
            loss.backward()
            optimizer.step()
        print("Batch number:{}, train_loss is: {}, accuracy: {}"
              .format(i + 1, loss, accuracy / len(trainloader)), times)

Since batch_size is 64, I expected one full iteration over the trainloader object to give times=64, but it actually gives times=938. Can someone explain why?

Data loaders in PyTorch work in a straightforward way: you define a dataset and wrap it in a DataLoader, a utility that samples elements from the underlying dataset. The most basic usage is to just set the batch_size option, the number of elements per batch. When you iterate over the data loader, each loop iteration gives you access to a single batch, and each batch contains batch_size elements, as specified when the loader is constructed.
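
Here is a minimal, self-contained sketch of that batching behavior (it uses a synthetic TensorDataset in place of your trainset; the shapes and batch_size=64 mirror your setup):

import torch
from torch.utils.data import TensorDataset, DataLoader

# 100 synthetic samples shaped like 28x28 grayscale images, with integer labels
data = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(data, labels)

loader = DataLoader(dataset, batch_size=64, shuffle=True)
print(len(loader))  # 2 -> ceil(100 / 64) batches per pass (drop_last=False by default)

for image, label in loader:
    print(image.shape, label.shape)
# torch.Size([64, 1, 28, 28]) torch.Size([64])   (first batch)
# torch.Size([36, 1, 28, 28]) torch.Size([36])   (last, partial batch)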

In your example, since you defined trainloader as DataLoader(trainset, batch_size=64, shuffle=True), then in the loop:

for image, label in trainloader:
    pass

image and label will each correspond to a single batch of 64 elements (64 images and their corresponding labels), so times is incremented once per batch, not once per element. This means you do not have to iterate over range(batches). However, you can iterate over the data loader multiple times, which effectively corresponds to epochs:

for epoch in range(epochs):
    for image, label in trainloader:
        optimizer.zero_grad()
        output = model(image)
        # ...
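
As a quick sanity check of the number you saw (a sketch, assuming trainset is a 60,000-sample dataset such as MNIST or FashionMNIST, and the default drop_last=False):

import math

# With drop_last=False, len(trainloader) is the number of batches per pass,
# i.e. ceil(len(trainset) / batch_size)
print(math.ceil(60000 / 64))  # 938
assert len(trainloader) == math.ceil(len(trainset) / 64)

So after one full pass over the loader, times equals the number of batches (938), not the batch size (64).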