加快pytorch代码中循环的不可并行性



我在pytorch中构建了一个网络,经过分析,发现大约90%的工作是在我的一个块中的for循环中完成的。问题是,由于依赖于被mask1屏蔽的先前值,这个循环是不可并行的(见下面的MWE(。

我试着用torch.jit.script编译它,速度在0.5s时可以忽略不计。我在循环中做的工作量很小,其他的都是矢量化的。附件是我正在使用的张量大小的MWE。

我读到C++应该更快。用C++编写代码会比torchscript更好吗?有没有其他方法可以显著提高循环的运行时间?谢谢

import torch
import time
n = 11300
f = 7800
batch_size = 10
device = "cuda" if torch.cuda.is_available() else "cpu"
inp1 = torch.randint(0, n, size=[batch_size, n, 4], device=device)
inp2 = torch.randint(0, n, size=[batch_size, f, 3], device=device)
inp3 = torch.randint(0, f, size=[batch_size, n, 2], device=device)
inp4 = torch.randint(0, n, size=[batch_size, n, 2], device=device)
inp5 = torch.randint(0, n, size=[batch_size, n, 2], device=device)

batch_list = torch.arange(batch_size, device=device)
mask1 = torch.ones([batch_size, n], dtype=torch.bool, device=device)
mask2 = torch.ones([batch_size, n], dtype=torch.bool, device=device)
mask3 = torch.ones([batch_size, n], dtype=torch.bool, device=device)
start_time = time.time()
for i in range(n):
if torch.all(~mask1[:, i]):
continue
batch_list_tmp = batch_list[mask1[:, i]]
mask2[mask1[:, i], i] = False
mask3[batch_list_tmp.unsqueeze(-1), inp1[:, i][:, ::2][batch_list_tmp]] = 0
mask1[batch_list_tmp, i] = False
concat_list_1 = inp5[
batch_list.unsqueeze(-1), inp4[batch_list.unsqueeze(-1), inp1[batch_list, i]].view([batch_size, -1])]
concat_list_2 = inp5[batch_list.unsqueeze(-1).unsqueeze(-1), inp4[
batch_list.unsqueeze(-1).unsqueeze(-1), concat_list_1[batch_list]].view([batch_size, 8, -1])[
batch_list]].view([batch_size, 8, -1])
closure = inp2[batch_list.unsqueeze(-1).unsqueeze(-1),
inp3[batch_list.unsqueeze(-1).unsqueeze(-1),
concat_list_2[batch_list]].view(batch_size, 8, -1)[batch_list]].view([batch_size, 8, -1])
mask1[batch_list_tmp.unsqueeze(-1), closure.view([batch_size, -1])[batch_list_tmp]] = False
end_time = time.time()
print(end_time-start_time) # takes ~5-7 seconds on server

一个想法是将此操作从流程中移除。例如,如果这是培训前的预处理步骤,您可以让一组工人提前预处理多个批次,并对它们进行排队,这样第一个准备好的批次出列的等待时间可以忽略不计或小得多。

另一个想法是,通过改变操作序列(即有许多操作序列来实现最终结果,其中一些可能是可并行的(,实际上可能会对for循环进行顶级反序列化。如果没有对你试图实现的目标的描述,很难判断这里是否存在这种情况。

最新更新