Pytorch数据加载器如何处理可变大小数据



我的数据集如下。这是第一个项目是用户ID,然后是用户点击的项目集。

0   24104   27359   6684
0   24104   27359
1   16742   31529   31485
1   16742   31529
2   6579    19316   13091   7181    6579    19316   13091
2   6579    19316   13091   7181    6579    19316
2   6579    19316   13091   7181    6579    19316   13091   6579
2   6579    19316   13091   7181    6579
4   19577   21608
4   19577   21608
4   19577   21608   18373
5   3541    9529
5   3541    9529
6   6832    19218   14144
6   6832    19218
7   9751    23424   25067   12606   26245   23083   12606

我定义一个自定义数据集来处理我的点击日志数据。

import torch.utils.data as data
class ClickLogDataset(data.Dataset):
    def __init__(self, data_path):
        self.data_path = data_path
        self.uids = []
        self.streams = []
        with open(self.data_path, 'r') as fdata:
            for row in fdata:
                row = row.strip('n').split('t')
                self.uids.append(int(row[0]))
                self.streams.append(list(map(int, row[1:])))
    def __len__(self):
        return len(self.uids)
    def __getitem__(self, idx):
        uid, stream = self.uids[idx], self.streams[idx]
        return uid, stream

然后,我使用数据加载程序从数据中检索迷你批次进行培训。

from torch.utils.data.dataloader import DataLoader
clicklog_dataset = ClickLogDataset(data_path)
clicklog_data_loader = DataLoader(dataset=clicklog_dataset, batch_size=16)
for uid_batch, stream_batch in stream_data_loader:
    print(uid_batch)
    print(stream_batch)

上面的代码返回与我预期的不同,我希望stream_batch成为长度16类型整数的2D张量。但是,我得到的是长度16的一维张量的列表,该列表只有一个元素,如下所示。为什么是?

#stream_batch
[tensor([24104, 24104, 16742, 16742,  6579,  6579,  6579,  6579, 19577, 19577,
        19577,  3541,  3541,  6832,  6832,  9751])]

那么,您如何处理样品长度不同的事实?torch.utils.data.DataLoader具有collate_fn参数,用于将样本列表转换为批处理。默认情况下,它可以对列表进行操作。您可以编写自己的collate_fn,例如0-键入输入,将其截断为预定义的长度或应用您选择的任何其他操作。

这就是我这样做的方式:

def collate_fn_padd(batch):
    '''
    Padds batch of variable length
    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.
    '''
    ## get sequence lengths
    lengths = torch.tensor([ t.shape[0] for t in batch ]).to(device)
    ## padd
    batch = [ torch.Tensor(t).to(device) for t in batch ]
    batch = torch.nn.utils.rnn.pad_sequence(batch)
    ## compute mask
    mask = (batch != 0).to(device)
    return batch, lengths, mask

然后,我将其传递给数据加载程序类,为collate_fn


在Pytorch论坛上似乎有一个巨大的帖子列表。让我链接到所有这些。他们都有自己的答案和讨论。在我看来,没有一种"标准方法",但是如果有权威参考,请分享。

理想的答案提到

很高兴
  • 效率,例如如果要在colate函数中用火炬在gpu中进行处理,而numpy

那种东西。

列表:

  • https://discuss.pytorch.org/t/how-to-to-to-create-batches-of-a-list-a-list-of-varying-dimension-tensors/50773
  • https://discuss.pytorch.org/t/how-to-to-to-create-a-dataloader-with-with-variable-dariable-size-input/8278
  • https://discuss.pytorch.org/t/using-variable-sied-sized-input-is-padding-required/18131
  • https://discuss.pytorch.org/t/dataloader-for-various-length-length-fenda/6418
  • https://discuss.pytorch.org/t/how-to-do-padding base-on-lengths/24442

桶:-https://discuss.pytorch.org/t/tensorflow-esque-bucket-by-sequence-length/41284

正如@jatentaki所建议的,我写了自定义的校正功能,而且工作正常。

def get_max_length(x):
    return len(max(x, key=len))
def pad_sequence(seq):
    def _pad(_it, _max_len):
        return [0] * (_max_len - len(_it)) + _it
    return [_pad(it, get_max_length(seq)) for it in seq]
def custom_collate(batch):
    transposed = zip(*batch)
    lst = []
    for samples in transposed:
        if isinstance(samples[0], int):
            lst.append(torch.LongTensor(samples))
        elif isinstance(samples[0], float):
            lst.append(torch.DoubleTensor(samples))
        elif isinstance(samples[0], collections.Sequence):
            lst.append(torch.LongTensor(pad_sequence(samples)))
    return lst
stream_dataset = StreamDataset(data_path)
stream_data_loader = torch.utils.data.dataloader.DataLoader(dataset=stream_dataset,                                                         
                                                            batch_size=batch_size,                                            
                                                        collate_fn=custom_collate,
                                                        shuffle=False)

相关内容

  • 没有找到相关文章

最新更新