我在github上看到了这个(这里的片段(:
(...)
for epoch in range(round):
for i, data in enumerate(dataloader, 0):
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
# train with real
netD.zero_grad()
real_cpu = data[0].to(device)
batch_size = real_cpu.size(0)
label = torch.full((batch_size,), real_label, device=device)
(...)
用batch_size = len(data[0])
代替batch_size = real_cpu.size(0)
会产生同样的效果吗?(或者至少使用batch_size = len(real_cpu)
?(我之所以这么问,是因为在循环for (X, y) in dataloader:
等过程中显示训练进度时,官方PyTorch教程包含了len(X)
。所以我想知道这两种方法在显示"当前"批次中的"样本"数量时是否等效。
如果使用批量大小为第一维度的数据,则可以将real_cpu.size(0)
与len(real_cpu)
或len(data[0])
进行交换。然而,当使用像LSTM这样的一些模型时,您可以在第二维度上使用批量大小,在这种情况下,您不能使用len
,而是使用real_cpu.size(1)
,例如