在DataLoader样本中循环时,len()与.size(0)



我在github上看到了这个(这里的片段(:

(...)
for epoch in range(round):
for i, data in enumerate(dataloader, 0):
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
# train with real
netD.zero_grad()
real_cpu = data[0].to(device)
batch_size = real_cpu.size(0)
label = torch.full((batch_size,), real_label, device=device)
(...)

batch_size = len(data[0])代替batch_size = real_cpu.size(0)会产生同样的效果吗?(或者至少使用batch_size = len(real_cpu)?(我之所以这么问,是因为在循环for (X, y) in dataloader:等过程中显示训练进度时,官方PyTorch教程包含了len(X)。所以我想知道这两种方法在显示"当前"批次中的"样本"数量时是否等效。

如果使用批量大小为第一维度的数据,则可以将real_cpu.size(0)len(real_cpu)len(data[0])进行交换。然而,当使用像LSTM这样的一些模型时,您可以在第二维度上使用批量大小,在这种情况下,您不能使用len,而是使用real_cpu.size(1),例如

最新更新