PyTorch自定义DataLoader与多个csv一起工作



我正在尝试定义一个自定义的PyTorch DataLoader,能够有效地从不同的巨大的读取csv无需将它们加载到内存中。问题的定义如下。为简单起见,假设我有两个csv

1.csv:
1, 2, 3
4, 5, 6
7, 8, 9
2.csv:
10,11,12
13,14,15
16,17,18

为简单起见,我们也假设批大小为1。发生器应产生两个张量:

Tensor_1: [1, 2, 3, 4, 5, 6, 7, 8, 9]
Tensor_2: [10, 11, 12, 13, 14, 15, 16, 17, 18]

这是因为对于每个有效索引,我应用了一个等于2的历史窗口,然后我将样本平坦化。

根据从多个csv文件加载数据的最快方法是什么,我编写了以下代码:

import numpy as np
import pandas as pd
import glob
from functools import lru_cache
from pathlib import Path
from pprint import pprint
from torch.utils.data import Dataset, DataLoader
import torch
@lru_cache()
def get_sample_count_by_file(path: Path) -> int:
c = 0
with path.open() as f:
for line in f:
c += 1
return c
class CSVDataset:
def __init__(self, csv_directory: str, extension: str = ".csv"):
self.directory = Path(csv_directory)
self.files = sorted((f, get_sample_count_by_file(f)) for f in self.directory.iterdir() if f.suffix == extension)
self._sample_count = sum(f[-1] for f in self.files)
def __len__(self):
return self._sample_count
def __getitem__(self, idx):
current_count = 0
history_window = 2
my_idx=idx+2
for file_, sample_count in self.files:
if current_count <= my_idx < current_count + sample_count:
break  
current_count += sample_count
file_idx = my_idx - current_count # the index we want to access in file_
if file_idx < 2:
file_idx += 2
with file_.open() as f:
data = []
for i, line in enumerate(f):
if i >= file_idx-history_window and i <= file_idx:
for v in line.split(","):
data.append(float(v))
data = np.array(data)
return torch.from_numpy(data)

dataset = CSVDataset("<PATH CONTAINING CSVs>")
loader = DataLoader(dataset, batch_size=1)
pprint(list(enumerate(loader)))

它完全适用于第一个文件,但当它切换到第二个CSV时存在麻烦(由于索引管理错误,有一些重复)。如何解决这个问题?

如何使用您的CSVDataset为一个csv,然后使用torch.utils.data.ConcatDataset将所有单独的csv数据集连接到一个单一的。Pytorch会为你处理索引,只要每个CSVDataset内的索引是一致的。

最新更新