如何在pytorch中批处理矩阵向量乘法(一个矩阵,多个向量)而不在内存中复制矩阵



我有大小为dn向量和单个d x d矩阵J。我想计算J与每个n向量的n矩阵向量乘积。

为此,我使用pytorch的expand()来获得J广播,但似乎在计算矩阵向量乘积时,pytorch在内存中实例化了一个完整的n x d x d张量。例如以下代码

device = torch.device("cuda:0")
n = 100_000_000
d = 10
x = torch.randn(n, d, dtype=torch.float32, device=device)
J = torch.randn(d, d, dtype=torch.float32, device=device).expand(n, d, d)
y = torch.sign(torch.matmul(J, x[..., None])[..., 0])

提高

RuntimeError: CUDA out of memory. Tried to allocate 37.25 GiB (GPU 0; 11.00 GiB total capacity; 3.73 GiB already allocated; 5.69 GiB free; 3.73 GiB reserved in total by PyTorch)

这意味着pytorch不必要地试图为矩阵Jn副本分配空间

如何在不耗尽GPU内存的情况下以矢量化的方式执行此任务(矩阵很小,所以我不想在每次矩阵向量乘法上循环(?

我认为这将解决它:

import torch
x = torch.randn(n, d)
J = torch.randn(d, d) # no need to expand
y = torch.matmul(J, x.T).T

使用表达式进行验证:

Jex = J.expand(n, d, d)
y1 = torch.matmul(Jex, x[..., None])[..., 0]
y = torch.matmul(J, x.T).T
torch.allclose(y1, y) # using allclose for float values
# tensor(True)

最新更新