选择性元素的最大运算,而不是所有元素



我正在编写PyTorch。在torch推理代码之间,我为自己的兴趣添加了一些外围代码。此代码运行良好,但速度太慢。原因可能是迭代。所以,我需要并行和快速的方法来做这件事。

可以在tensor、Numpy或python数组中执行此操作。

我制作了一个名为selective_max的函数来查找数组中的最大值。但问题是,我不希望在整个数组中有一个最大值,而是在mask数组指定的特定候选数组中。让我展示一下这个功能的要点(下面展示了代码本身(

输入

x [batch_size , dim, num_points, k]:x是原始输入,但它通过x.permute(0,2,1,3)变为[批大小,num_points,dim,k]。

batch_size是深度学习社会中一个众所周知的定义。在每一个小批量中,都有很多点。并且用dim长度特征来表示单个点。对于每个特征元素,都有k的潜在候选者,这是稍后max函数的目标。

mask [batch_size, num_points, k]:这个数组类似于没有dimx。其元素是01。所以,我用它作为屏蔽信号,就像只对1屏蔽值进行最大运算一样。

请参阅下面的代码并进行解释。我使用3进行迭代。假设我们针对特定的批次和特定的点。对于特定批次和特定点,x具有[dim,k]数组。掩模具有[k]阵列,该阵列由01组成。因此,我从[k]数组中提取非零索引,并将其用于逐个dim提取x中的特定元素('fork in range(dim('(。

玩具示例

假设我们处于第二个迭代阶段。因此,我们现在有[dim, k]用于x[k]用于mask。对于这个玩具示例,i假定k=3dim=4x = [[3,2,1],[5,6,4],[9,8,7],[12,11,10]]k=[0,1,1]。因此,输出将是[2,6,8,11],而不是[3, 6, 9, 12]

上一次尝试

我尝试{ mask.repeat(0,0,1,0) *(element-wise mul) x }并执行max操作。但是,"0"可能是最大值,因为x可能在所有数组中都有负值。因此,这将导致错误的操作。

def selective_max2(x, mask): # x : [batch_size , dim, num_points, k] , mask : [batch_size, num_points, k]
batch_size = x.size(0)
dim = x.size(1)
num_points = x.size(2)
k = x.size(3)
device = torch.device('cuda')
x = x.permute(0,2,1,3) # : [batch, num_points, dim, k]
#print('permuted x dimension : ',x.size())
x = x.detach().cpu().numpy()
mask = mask.cpu().numpy()
output = np.zeros((batch_size,num_points,dim))
for i in range(batch_size):
for j in range(num_points):
query=np.nonzero(mask[i][j]) # among mask entries, we get the index of nonzero values.
for k in range(dim): # for different k values, we get the max value.
# query is index of nonzero values. so, using query, we can get the values that we want.
output[i][j][k] = np.max(x[i][j][k][query])
output = torch.from_numpy(output).float().to(device=device)
output = output.permute(0,2,1).contiguous()
return output

免责声明:我已经按照你的玩具示例(然而,在保持通用性的情况下(编写了以下解决方案

第一件事是将k扩展为x(将它们都视为PyTorch张量(:

k_expanded = k.expand_as(x)

然后,您选择k_expanded中存在1的元素,并将结果张量视为x行数(写为x.shape[0](,将k中的1行数(或掩码(视为列数。到目前为止,我们已经选择了要查询最大元素的范围。然后,使用max(1)沿着行维度(如.sum(0)所示(找到最大值

values, indices = x[k_expanded == 1].view(x.shape[0], (k == 1).sum(0)).max(1)
values
Out[29]: tensor([ 2,  6,  8, 11])

基准

def find_max_elements_inside_tensor_range(arr, mask, return_indices=False):
mask_expanded = mask.expand_as(arr)
values, indices = x[k_expanded==1].view(x.shape[0], (k == 1).sum(0)).max(1)
return (values, indices) if return_indices else values

刚刚添加了第三个参数,以防您想要获得数字索引

%timeit find_max_elements_inside_tensor_range(x, k)
38.4 µs ± 534 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

注意:上述解决方案也适用于各种形状的张量和遮罩

最新更新