我正在尝试在 PyTorch 中训练一个深度学习模型,这些模型已经存储到特定维度。我想使用小批量训练我的模型,但小批量大小并没有整齐地划分每个存储桶中的示例数量。
我在上一篇文章中看到的一种解决方案是用额外的空格填充图像(无论是动态还是在训练开始时一次全部(,但我不想这样做。相反,我想允许批量大小在训练期间灵活。
具体来说,如果N
是存储桶中的图像数量,B
是批量大小,那么对于该存储桶,如果B
除以N
,我想获得N // B
批次,否则N // B + 1
批次。最后一批可以包含少于B
个示例。
例如,假设我有索引 [0, 1, ..., 19],包括索引,并且我想使用 3 的批大小。
索引 [0, 9] 对应于存储桶 0 中的图像(形状 (C, W1,H1((索引 [10, 19] 对应于存储桶 1 中的图像(形状 (C, W2, H2((
(所有图像的通道深度都相同(。那么可接受的索引分区将是
batches = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[9],
[10, 11, 12],
[13, 14, 15],
[16, 17, 18],
[19]
]
我更愿意分别处理索引为 9 和 19 的图像,因为它们具有不同的尺寸。
浏览 PyTorch 的文档,我找到了生成小批量索引列表的BatchSampler
类。我创建了一个自定义Sampler
类来模拟上述索引的分区。如果有帮助,这是我的实现:
class CustomSampler(Sampler):
def __init__(self, dataset, batch_size):
self.batch_size = batch_size
self.buckets = self._get_buckets(dataset)
self.num_examples = len(dataset)
def __iter__(self):
batch = []
# Process buckets in random order
dims = random.sample(list(self.buckets), len(self.buckets))
for dim in dims:
# Process images in buckets in random order
bucket = self.buckets[dim]
bucket = random.sample(bucket, len(bucket))
for idx in bucket:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
# Yield half-full batch before moving to next bucket
if len(batch) > 0:
yield batch
batch = []
def __len__(self):
return self.num_examples
def _get_buckets(self, dataset):
buckets = defaultdict(list)
for i in range(len(dataset)):
img, _ = dataset[i]
dims = img.shape
buckets[dims].append(i)
return buckets
但是,当我使用自定义Sampler
类时,我生成以下错误:
Traceback (most recent call last):
File "sampler.py", line 143, in <module>
for i, batch in enumerate(dataloader):
File "/home/roflcakzorz/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 263, in __next__
indices = next(self.sample_iter) # may raise StopIteration
File "/home/roflcakzorz/anaconda3/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 139, in __iter__
batch.append(int(idx))
TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'
DataLoader
类似乎期望传递索引,而不是索引列表。
我不应该为此任务使用自定义Sampler
类吗?我还考虑过制作一个自定义collate_fn
传递给DataLoader
,但通过这种方法,我不相信我可以控制允许哪些索引位于同一个迷你批次中。任何指导将不胜感激。
每个样本是否有 2 个网络(必须修复 CNN 内核大小(。如果是,只需将上述custom_sampler
传递给 DataLoader 类的 batch_sampler 参数。这将解决问题。
嗨,由于每批都应该包含相同维度的图像,因此您的CustomSampler
工作正常,需要将其作为参数传递给mx.gluon.data.DataLoader
,关键字batch_sampler
。但是,如文档中所述,请记住这一点:
"如果指定了
batch_sampler
,则不要指定shuffle
、sampler
和last_batch
">