我正在处理多个csv文件,每个文件都包含多个1D数据。我有大约9000个这样的文件,总的组合数据大约是40GB。
我写了一个这样的数据加载器:
class data_gen(torch.utils.data.Dataset):
def __init__(self, files):
self.files = files
my_data = np.genfromtxt('/data/'+files, delimiter=',')
self.dim = my_data.shape[1]
self.data = []
def __getitem__(self, i):
file1 = self.files
my_data = np.genfromtxt('/data/'+file1, delimiter=',')
self.dim = my_data.shape[1]
for j in range(my_data.shape[1]):
tmp = np.reshape(my_data[:,j],(1,my_data.shape[0]))
tmp = torch.from_numpy(tmp).float()
self.data.append(tmp)
return self.data[i]
def __len__(self):
return self.dim
我将整个数据集加载到数据加载器的方式就像通过for
循环:
for x_train in tqdm(train_files):
train_dl_spec = data_gen(x_train)
train_loader = torch.utils.data.DataLoader(
train_dl_spec, batch_size=128, shuffle=True, num_workers=8, pin_memory=True)
for data in train_loader:
但这项工作进展非常缓慢。我想知道我是否可以将所有数据存储在一个文件中,但我没有足够的RAM。那么有办法绕过它吗?
如果有办法,请告诉我。
我以前从未使用过pytorch,我承认我真的不知道发生了什么。尽管如此,我几乎可以肯定你用错了Dataset
。
据我所知,数据集是所有数据的抽象,每个索引都返回一个样本。假设9000个文件中的每一个都有10行(示例(,21行表示第三个文件和第二行(使用0索引(。
因为你有太多的数据,你不想把所有的东西都加载到内存中。因此,数据集应该只获取一个值,而DataLoader会创建一批值。
几乎可以肯定的是,有些优化可以应用于我所做的工作,但也许这可以让你开始。我用以下文件创建了目录csvs
:
❯ cat csvs/1.csv
1,2,3
2,3,4
3,4,5
❯ cat csvs/2.csv
21,21,21
34,34,34
66,77,88
然后我创建了这个数据集类。它将一个目录作为输入(所有CSV都存储在其中(。那么存储在内存中的唯一东西就是每个文件的名称和它的行数。当请求一个项时,我们会找出哪个文件包含该索引,然后返回该行的张量。
通过只遍历文件,我们从不将文件内容存储在内存中。不过,这里的改进不是迭代文件列表以找出哪一个是相关的,也不是在访问连续索引时使用生成器和状态。
(因为在访问索引8时进行访问,在一个10行的文件中,我们对前7行进行了无用的迭代,这是我们无能为力的。但在访问索引9时,最好计算出我们可以只返回下一行,而不是再次迭代前8行。(
import numpy as np
from functools import lru_cache
from pathlib import Path
from pprint import pprint
from torch.utils.data import Dataset, DataLoader
@lru_cache()
def get_sample_count_by_file(path: Path) -> int:
c = 0
with path.open() as f:
for line in f:
c += 1
return c
class CSVDataset:
def __init__(self, csv_directory: str, extension: str = ".csv"):
self.directory = Path(csv_directory)
self.files = sorted((f, get_sample_count_by_file(f)) for f in self.directory.iterdir() if f.suffix == extension)
self._sample_count = sum(f[-1] for f in self.files)
def __len__(self):
return self._sample_count
def __getitem__(self, idx):
current_count = 0
for file_, sample_count in self.files:
if current_count <= idx < current_count + sample_count:
# stop when the index we want is in the range of the sample in this file
break # now file_ will be the file we want
current_count += sample_count
# now file_ has sample_count samples
file_idx = idx - current_count # the index we want to access in file_
with file_.open() as f:
for i, line in enumerate(f):
if i == file_idx:
data = np.array([float(v) for v in line.split(",")])
return torch.from_numpy(data)
现在我们可以按照我认为的意图使用DataLoader:
dataset = CSVDataset("csvs")
loader = DataLoader(dataset, batch_size=4)
pprint(list(enumerate(loader)))
"""
[(0,
tensor([[ 1., 2., 3.],
[ 2., 3., 4.],
[ 3., 4., 5.],
[21., 21., 21.]], dtype=torch.float64)),
(1, tensor([[34., 34., 34.],
[66., 77., 88.]], dtype=torch.float64))]
"""
您可以看到这正确地返回了一批批数据。您可以处理每个批次并只将该批次存储在内存中,而不是打印出来。
有关更多信息,请参阅文档:https://pytorch.org/tutorials/recipes/recipes/custom_dataset_transforms_loader.html#part-3-the-dataloader