I have built a dataset in which I run various checks on the images being loaded. I then pass this dataset to a DataLoader.
In my Dataset class, if an image fails the checks I return None for that sample, and I have a custom collate_fn that removes all the Nones from the retrieved batch and returns the remaining valid samples.
However, the batch returned at this point can have a varying size. Is there a way to tell collate_fn to keep sourcing data until the batch reaches a given length?
class DataSet():
    def __init__(self, csv_file, image_dir):
        # initialise dataset:
        # load csv file and image directory
        self.csv_file = csv_file
        self.image_dir = image_dir

    def __getitem__(self, idx):
        # load one sample
        # if the image is too dark, return None
        # else return one image and its matching label
        ...

dataset = DataSet(csv_file='../', image_dir='../../')
dataloader = DataLoader(dataset, batch_size=4,
                        shuffle=True, num_workers=1, collate_fn=my_collate)
def my_collate(batch):
    # batch of size 4: [{tensor image, tensor label}, {}, {}, {}]
    # could come back as something like G = [None, {}, {}, {}]
    batch = list(filter(lambda x: x is not None, batch))
    # this gets rid of Nones in the batch; the example above would become G = [{}, {}, {}]
    # I want len(G) = 4 -- so how do I sample another dataset entry?
    return torch.utils.data.dataloader.default_collate(batch)
There are two hacks that can be used to work around the issue; pick one:
Quick option: pad the batch with its own existing samples:
def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones
    if len_batch > len(batch):  # if samples are missing, just reuse existing members
                                # (doesn't work if every sample in a batch is rejected)
        diff = len_batch - len(batch)
        for i in range(diff):
            batch = batch + batch[:diff]
    return torch.utils.data.dataloader.default_collate(batch)
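Note that the padding line in this quick option runs once per missing sample and appends diff items each time. A torch-free sketch with plain integers standing in for samples shows the effect (`default_collate` omitted):

```python
def pad_with_existing(batch):
    """Reproduces the padding loop from the 'quick option' above, minus torch."""
    len_batch = len(batch)
    batch = [x for x in batch if x is not None]
    if len_batch > len(batch):
        diff = len_batch - len(batch)
        for i in range(diff):
            batch = batch + batch[:diff]
    return batch

# One missing sample: the result has the right length.
print(len(pad_with_existing([1, 2, 3, None])))     # 4
# Two missing samples: diff items are appended diff times -> batch too long.
print(len(pad_with_existing([1, None, 2, None])))  # 6
```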
Better option: just source another sample from the dataset at random:
import numpy as np

def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones
    if len_batch > len(batch):  # source the missing samples from the original dataset at random
        diff = len_batch - len(batch)
        for i in range(diff):
            batch.append(dataset[np.random.randint(0, len(dataset))])
    return torch.utils.data.dataloader.default_collate(batch)
This version worked for me, because sometimes even those randomly drawn replacements are None:
import numpy as np

def my_collate(batch):
    len_batch = len(batch)
    batch = list(filter(lambda x: x is not None, batch))
    if len_batch > len(batch):
        db_len = len(dataset)
        diff = len_batch - len(batch)
        while diff != 0:
            a = dataset[np.random.randint(0, db_len)]
            if a is None:
                continue
            batch.append(a)
            diff -= 1
    return torch.utils.data.dataloader.default_collate(batch)
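The resampling loop itself can be exercised without torch. In this toy sketch the dataset is a hypothetical list in which None marks a corrupted entry, and the final `default_collate` call is omitted:

```python
import random

# Hypothetical stand-in dataset; None marks a corrupted entry.
toy_dataset = [1, None, 2, None, 3, 4]

def my_collate_toy(batch):
    len_batch = len(batch)
    batch = [x for x in batch if x is not None]
    diff = len_batch - len(batch)
    while diff != 0:
        a = toy_dataset[random.randint(0, len(toy_dataset) - 1)]
        if a is None:
            continue      # the resampled entry was also corrupted; draw again
        batch.append(a)
        diff -= 1
    return batch          # real code would call default_collate(batch) here

print(len(my_collate_toy([1, None, 2, None])))  # 4
```

Because the while loop only decrements diff after appending a valid sample, the returned batch always reaches the original length and contains no Nones.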
[Edit] An updated version of the snippet below can be found here: https://github.com/project-lighter/lighter/blob/main/lighter/utils/collate.py
Thanks to Brian Formento for asking about this and suggesting how to solve it. As mentioned, the best option of replacing bad examples with new ones has two issues:
- the newly sampled examples could also be corrupted;
- the dataset is not in scope.
Here is a solution to both: issue 1 is handled with a recursive call, and issue 2 by creating a partial function of the collate function with the dataset fixed in place.
import random
import torch

def collate_fn_replace_corrupted(batch, dataset):
    """Collate function that allows to replace corrupted examples in the
    dataloader. It expects that the dataloader returns 'None' when that occurs.
    The 'None's in the batch are replaced with other examples sampled randomly.

    Args:
        batch (torch.Tensor): batch from the DataLoader.
        dataset (torch.utils.data.Dataset): dataset that the DataLoader is loading.
            Fix it with functools.partial and pass the resulting partial function,
            which only requires the 'batch' argument, to DataLoader's 'collate_fn' option.

    Returns:
        torch.Tensor: batch with new examples instead of corrupted ones.
    """
    # Idea from https://stackoverflow.com/a/57882783
    original_batch_len = len(batch)
    # Filter out all the Nones (corrupted examples)
    batch = list(filter(lambda x: x is not None, batch))
    filtered_batch_len = len(batch)
    # Num of corrupted examples
    diff = original_batch_len - filtered_batch_len
    if diff > 0:
        # Replace the corrupted examples with other randomly sampled examples
        batch.extend([dataset[random.randint(0, len(dataset) - 1)] for _ in range(diff)])
        # Recursive call to replace the replacements if they are corrupted
        return collate_fn_replace_corrupted(batch, dataset)
    # Finally, when the whole batch is fine, return it
    return torch.utils.data.dataloader.default_collate(batch)
However, you cannot pass this function to the DataLoader directly, because a collate function must take only a single argument, batch. To make that work, we create a partial function with the dataset fixed in place and pass the partial function to the DataLoader.
import functools
from torch.utils.data import DataLoader

collate_fn = functools.partial(collate_fn_replace_corrupted, dataset=dataset)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        pin_memory=pin_memory,
                        collate_fn=collate_fn)
For anyone looking to reject training examples on the fly, instead of using tricks to work around the DataLoader's collate_fn, one can simply use an IterableDataset and write the __iter__ and __next__ functions along these lines:

def __iter__(self):
    return self

def __next__(self):
    # load the next non-None example
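A torch-free sketch of that pattern (the class name and data are illustrative; a real implementation would subclass torch.utils.data.IterableDataset and load real samples):

```python
class SkippingDataset:
    """Iterator that yields only samples passing the quality check."""
    def __init__(self, samples):
        self.samples = samples  # None marks a rejected sample
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Load the next non-None example, skipping rejected ones.
        while self.i < len(self.samples):
            sample = self.samples[self.i]
            self.i += 1
            if sample is not None:
                return sample
        raise StopIteration

print(list(SkippingDataset([1, None, 2, None, 3])))  # [1, 2, 3]
```

Since rejected samples never leave the dataset, every batch the DataLoader assembles is already full-sized.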
Why not solve this inside the dataset class, in the __getitem__ method? Instead of returning None when the data is bad, recursively request a different random index:
import numpy as np

class DataSet():
    def __getitem__(self, idx):
        sample = load_sample(idx)
        if is_no_good(sample):
            idx = np.random.randint(0, len(self))  # high bound is exclusive
            sample = self[idx]
        return sample
This way, you never have to deal with batches of different sizes.
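A self-contained toy version of that idea, with a plain list in place of the real loading code (load_sample and is_no_good from the sketch above are replaced by a None check):

```python
import random

class ResamplingDataset:
    """Instead of returning None for a bad sample, retry a random index."""
    def __init__(self, data):
        self.data = data  # None marks a corrupted sample

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if sample is None:                          # sample is no good
            idx = random.randint(0, len(self) - 1)  # random.randint is inclusive
            sample = self[idx]                      # recursive retry
        return sample

ds = ResamplingDataset([10, None, 30])
print(ds[1] in (10, 30))  # True: index 1 is corrupted, so a valid sample is returned instead
```

One caveat: if every sample were corrupted, the recursion would never terminate, so this only works when bad samples are a minority.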
The "quick" option above has a problem: the padding line sits inside the for loop, so each of the diff iterations appends diff samples, and the batch comes out too long whenever more than one sample is rejected. Below is a fixed version.
def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones
    if len_batch > len(batch):  # if samples are missing, just reuse existing members
                                # (doesn't work if every sample in a batch is rejected)
        diff = len_batch - len(batch)
        batch = batch + batch[:diff]  # assumes diff <= len(batch)
    return torch.utils.data.dataloader.default_collate(batch)
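The fixed padding step, checked on plain integers standing in for samples (torch-free sketch, `default_collate` omitted):

```python
def pad_fixed(batch):
    """Pad the filtered batch back to its original length, once."""
    len_batch = len(batch)
    batch = [x for x in batch if x is not None]
    if len_batch > len(batch):
        diff = len_batch - len(batch)
        batch = batch + batch[:diff]  # single pass, assumes diff <= len(batch)
    return batch

print(pad_fixed([1, None, 2, None]))       # [1, 2, 1, 2]
print(len(pad_fixed([1, 2, 3, None])))     # 4
```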