我有一个问题,如何从pytorch数据加载器中获得批量迭代的总数?
以下是训练的通用代码
for i, batch in enumerate(dataloader):
那么,有没有任何方法可以得到";对于循环";?
在我的NLP问题中,迭代的总次数与int(n_train_samples/batch_size(不同。。。
例如,如果我只截断10000个样本的训练数据,并将批大小设置为1024,那么在我的NLP问题中会发生363次迭代。
我想知道如何得到"中的总迭代次数;for循环;。
谢谢。
len(dataloader)
返回批次总数。它取决于数据集的__len__
函数,因此请确保设置正确。
创建数据加载器时有一个附加参数。它被称为drop_last
。
如果为drop_last=True
,则长度为number_of_training_examples // batch_size
。如果是drop_last=False
,则可能是number_of_training_examples // batch_size +1
。
BS=128
ds_train = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t_train)
dl_train = DataLoader( ds_train, batch_size=BS, drop_last=True, shuffle=True)
对于预定义的数据集,您可能会得到许多示例,如:
# number of examples
len(dl_train.dataset)
数据加载器内的正确批次数始终为:
# number of batches
len(dl_train)