我正在处理一个二进制分类问题。我有大约150万个数据点,特征空间的维度是100万。该数据集存储为稀疏阵列,密度约为0.0001。对于这篇文章,我将限制范围,假设模型是一个浅前馈神经网络,也假设维度已经优化(因此不能降低到100万以下(。从这些数据中创建小批量以提供给网络的简单方法将花费大量时间(例如,从输入数组的torch.sparse.FloatTensor
表示创建TensorDataset
(映射样式(并将DataLoader
包裹在其周围的基本方法意味着大约20秒可以将32个小批量数据提供给网络,而不是大约0.1秒来执行实际训练(。我正在寻找加快速度的方法。
我尝试过的
- 我最初认为,在
DataLoader
的每次迭代中,从如此大的稀疏阵列中读取数据是计算密集型的,所以我将这个稀疏阵列分解为更小的稀疏阵列 - 为了让
DataLoader
以迭代的方式从这些多个稀疏阵列中读取,我用IterableDataset
替换了DataLoader
中的映射样式数据集,并将这些较小的稀疏阵列流式传输到这个IterableDataset中,如下所示:
from itertools import chain
from scipy import sparse
class SparseIterDataset(torch.utils.data.IterableDataset):
def __init__(self, fpaths):
super(SparseIter).__init__()
self.fpaths = fpaths
def read_from_file(self, fpath):
data = sparse.load_npz(fpath).toarray()
for d in data:
yield torch.Tensor(d)
def get_stream(self, fpaths):
return chain.from_iterable(map(self.read_from_file, fpaths))
def __iter__(self):
return self.get_stream(self.fpaths)
通过这种方法,我能够将时间从20秒左右的天真基本情况降低到32秒左右的每小批0.2秒。然而,考虑到我的数据集有大约150万个样本,这仍然意味着要花很多时间来通过数据集。(作为比较,尽管这有点像苹果和桔子,但在scikit learn上对原始稀疏阵列进行逻辑回归,在整个数据集中每次迭代大约需要约6秒。使用pytorch,按照我刚才概述的方法,在一个时期内加载所有的迷你批次需要约3000秒(
我知道但尚未尝试的一件事是通过在DataLoader
中设置num_workers
参数来使用多进程数据加载。不过,我相信在可迭代风格的数据集的情况下,这也有其自身的局限性。此外,即使是10倍的加速也意味着在加载小批量时每个历元大约300秒。我觉得我太慢了!您是否可以建议其他方法/改进/最佳实践?
非稀疏形式的数据集将是1.5M x 1M x 1字节=1.5TB作为uint8,或1.5M x 1M x 4字节=6TB作为float32。在现代CPU上,简单地从内存到CPU读取6TB可能需要5-10分钟(取决于体系结构(,从CPU到GPU的传输速度会比这慢一点(PCIe上的NVIDIA V100理论上为32GB/s(。
方法:
-
单独对所有内容进行基准测试-例如在jupyter 中
%%timeit数据=稀疏.load_npz(fpath(.toarray((
%%timeit densit=data.toarray((#取消稀疏以进行比较
%%timeit t=torc.tensor(数据(#可能与上方的线大致相同
同时打印出所有内容的形状和数据类型,以确保它们符合预期。我还没有试过运行你的代码,但我很确定(a(sparse.load_npz非常快,不太可能成为瓶颈,但(b(torch.tensor(data(产生了一个密集的张量,在这里也很慢
- 使用torch.sparse。我认为torch稀疏张量在大多数情况下可以用作正则张量。你必须做一些数据准备才能从scipy.sparse转换为torch.s稀疏:
稀疏张量表示为一对稠密张量:值和索引的2D张量。稀疏张量可以通过提供这两个张量来构造,以及稀疏张量的大小
您提到了torch.sparse.FloatTensor
,但我很确定您的代码中没有生成稀疏张量-没有理由期望这些张量只通过将scipy.sparse数组传递给常规张量构造函数来构造,因为它们通常不是这样生成的。
如果你找到了一个好方法,我建议你把它作为一个项目或git发布在github上,这将非常有用。
- 如果torch.s稀疏不起作用,请考虑其他方法,要么仅在GPU上将数据转换为密集数据,要么避免完全转换数据
另请参阅:https://towardsdatascience.com/sparse-matrices-in-pytorch-be8ecaccae6https://github.com/rusty1s/pytorch_sparse