我正在编写PyTorch。在torch推理代码之间,我为自己的兴趣添加了一些外围代码。此代码运行良好,但速度太慢。原因可能是迭代。所以,我需要并行和快速的方法来做这件事。
可以在tensor、Numpy或python数组中执行此操作。
我制作了一个名为selective_max
的函数来查找数组中的最大值。但问题是,我不希望在整个数组中有一个最大值,而是在mask
数组指定的特定候选数组中。让我展示一下这个功能的要点(下面展示了代码本身(
输入
x [batch_size , dim, num_points, k]
:x是原始输入,但它通过x.permute(0,2,1,3)
变为[批大小,num_points,dim,k]。
batch_size
是深度学习社会中一个众所周知的定义。在每一个小批量中,都有很多点。并且用dim
长度特征来表示单个点。对于每个特征元素,都有k
的潜在候选者,这是稍后max
函数的目标。
mask [batch_size, num_points, k]
:这个数组类似于没有dim
的x
。其元素是0
或1
。所以,我用它作为屏蔽信号,就像只对1
屏蔽值进行最大运算一样。
请参阅下面的代码并进行解释。我使用3
进行迭代。假设我们针对特定的批次和特定的点。对于特定批次和特定点,x
具有[dim,k]数组。掩模具有[k]阵列,该阵列由0
或1
组成。因此,我从[k]数组中提取非零索引,并将其用于逐个dim提取x
中的特定元素('fork in range(dim('(。
玩具示例
假设我们处于第二个迭代阶段。因此,我们现在有[dim, k]
用于x
,[k]
用于mask
。对于这个玩具示例,i
假定k=3
和dim=4
。x = [[3,2,1],[5,6,4],[9,8,7],[12,11,10]]
、k=[0,1,1]
。因此,输出将是[2,6,8,11]
,而不是[3, 6, 9, 12]
。
上一次尝试
我尝试{ mask.repeat(0,0,1,0) *(element-wise mul) x }
并执行max
操作。但是,"0"可能是最大值,因为x可能在所有数组中都有负值。因此,这将导致错误的操作。
def selective_max2(x, mask): # x : [batch_size , dim, num_points, k] , mask : [batch_size, num_points, k]
batch_size = x.size(0)
dim = x.size(1)
num_points = x.size(2)
k = x.size(3)
device = torch.device('cuda')
x = x.permute(0,2,1,3) # : [batch, num_points, dim, k]
#print('permuted x dimension : ',x.size())
x = x.detach().cpu().numpy()
mask = mask.cpu().numpy()
output = np.zeros((batch_size,num_points,dim))
for i in range(batch_size):
for j in range(num_points):
query=np.nonzero(mask[i][j]) # among mask entries, we get the index of nonzero values.
for k in range(dim): # for different k values, we get the max value.
# query is index of nonzero values. so, using query, we can get the values that we want.
output[i][j][k] = np.max(x[i][j][k][query])
output = torch.from_numpy(output).float().to(device=device)
output = output.permute(0,2,1).contiguous()
return output
免责声明:我已经按照你的玩具示例(然而,在保持通用性的情况下(编写了以下解决方案
第一件事是将k
扩展为x
(将它们都视为PyTorch张量(:
k_expanded = k.expand_as(x)
然后,您选择k_expanded
中存在1
的元素,并将结果张量视为x
行数(写为x.shape[0]
(,将k
中的1
行数(或掩码(视为列数。到目前为止,我们已经选择了要查询最大元素的范围。然后,使用max(1)
沿着行维度(如.sum(0)
所示(找到最大值
values, indices = x[k_expanded == 1].view(x.shape[0], (k == 1).sum(0)).max(1)
values
Out[29]: tensor([ 2, 6, 8, 11])
基准
def find_max_elements_inside_tensor_range(arr, mask, return_indices=False):
mask_expanded = mask.expand_as(arr)
values, indices = x[k_expanded==1].view(x.shape[0], (k == 1).sum(0)).max(1)
return (values, indices) if return_indices else values
刚刚添加了第三个参数,以防您想要获得数字索引
%timeit find_max_elements_inside_tensor_range(x, k)
38.4 µs ± 534 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
注意:上述解决方案也适用于各种形状的张量和遮罩