为什么DataLoader返回与batch_size长度不同的列表?



我正在编写一个定制的数据加载器,而返回的值使我感到困惑。

import torch
import torch.nn as nn
import numpy as np
import torch.utils.data as data_utils
class TestDataset:
def __init__(self):
self.db = np.random.randn(20, 3, 60, 60)
def __getitem__(self, idx):
img = self.db[idx]
return img, img.shape[1:]
def __len__(self):
return self.db.shape[0]

if __name__ == '__main__':
test_dataset = TestDataset()
test_dataloader = data_utils.DataLoader(test_dataset,
batch_size=1,
num_workers=4,
shuffle=False, 
pin_memory=True
)
for i, (imgs, sizes) in enumerate(test_dataloader):
print(imgs.size())  # torch.Size([1, 3, 60, 60])
print(sizes)  # [tensor([60]), tensor([60])]
break

为什么"sizes"返回长度为2的列表?我觉得应该是"火炬"。大小([1,2])";表示图像的高度和宽度(1 batch_size)。

更进一步,是否返回列表的长度batch_size相同? 如果我想要得到大小,我必须写"sizes = [sizes[0][0].item(), sizes[1][0].item()]"这让我很困惑。

感谢您的宝贵时间。

这是由collate_fn函数及其默认行为引起的。它的主要目的是简化批量制备过程。因此,您可以自定义批量准备过程来更新此函数。如文档collate_fn中所述,它会自动将NumPy数组和Python数值转换为PyTorch张量,并保留数据结构。所以在你的例子中它返回[张量([60]),张量([60])]。在许多情况下,你返回带有标签的图像作为张量(而不是图像的大小),并将其前馈到神经网络。我不知道为什么在枚举时返回图像大小,但是您可以添加自定义collate_fn作为:

def collate_fn(data):
imgs, lengths = data[0][0],data[0][1]    
return torch.tensor(imgs), torch.tensor([lengths])

那么你应该将它设置为DataLoader的参数:

test_dataloader = DataLoader(test_dataset,
batch_size=1,
num_workers=4,
shuffle=False, 
pin_memory=True, collate_fn=collate_fn
)

那么你可以循环为:

for i, (imgs, sizes) in enumerate(test_dataloader):
print(imgs.size())
print(sizes)  
print(sizes.size())  
break

,输出如下:

torch.Size([3, 60, 60])
tensor([[60, 60]])
torch.Size([1, 2])

毕竟,我想再加一点,你不应该只返回self.db。shape[0] inlen函数。在这种情况下,你的批量大小是1,这是可以的;但是,当批大小改变时,它不会返回#batches的真实值。你可以这样更新你的类:

class TestDataset:
def __init__(self, batch_size=1):
self.db = np.random.randn(20, 3, 60, 60)
self._batch_size = batch_size

def __getitem__(self, idx):
img = self.db[idx]
return img, img.shape[1:]
def __len__(self):
return self.db.shape[0]/self._batch_size

为什么"返回长度为2的列表?

返回从db中切片的单个元素的切片shape。下面的代码片段应该更清楚:

import numpy as np
db = np.random.randn(20, 3, 60, 60)
img = db[0]
img.shape # (3, 60, 60)
img.shape[1:] # (60, 60)

此外,返回列表的长度是否与batch_size吗?

你为什么要从DataLoader返回这个?从Dataset返回image:

def __getitem__(self, idx):
return self.db[idx]

对于batch_size=12,您将得到形状(12, 3, 60, 60)的输出。你可以从这个示例中得到形状,不要在Dataset中创建它,没有意义。

相关内容

  • 没有找到相关文章

最新更新