给定两个数组a
和b
,如何有效地找出b
中所有元素在a
中值相等的组合?
下面是一个例子:
给定
a = [0, 0, 0, 1, 1, 2, 2, 2, 2]
b = [1, 2, 4, 5, 9, 3, 7, 22, 10]
如何计算
c = [[1, 2],
[1, 4],
[2, 4],
[5, 9],
[3, 7],
[3, 22],
[3, 10],
[7, 22],
[7, 10],
[22, 10]]
?
a
可以假定是排序的。
我可以用循环来做这个,比如:
import torch
a = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
b = torch.tensor([1, 2, 4, 5, 9, 3, 7, 22, 10])
jumps = torch.cat((torch.tensor([0]),
torch.where(a.diff() > 0)[0] + 1,
torch.tensor([len(a)])))
cs = []
for i in range(len(jumps) - 1):
cs.append(torch.combinations(b[jumps[i]:jumps[i + 1]]))
c = torch.cat(cs)
是否有有效的方法来避免循环?该解决方案应该适用于CPU和CUDA。同时,解的运行时间应为O(m*m),其中m是a
中相等元素的最大个数,而不是O(n*n),其中n是a
的长度。
我更喜欢pytorch的解决方案,但我对numpy的解决方案也很好奇。
我认为使用torch的开销只适用于更大的数据集,因为在函数中基本上没有计算困难,我认为您可以使用
实现相同的结果:from collections import Counter
def find_combinations1(a, b):
count_a = Counter(a)
combinations = []
for x in set(b):
if count_a[x] == b.count(x):
combinations.append(x)
return combinations
或者更简单的:
def find_combinations2(a, b):
return list(set(a) & set(b))
对于pytorch
,我认为最简单的方法是:
import torch
def find_combinations3(a, b):
a = torch.tensor(a)
b = torch.tensor(b)
eq = torch.eq(a, b.view(-1, 1))
indices = torch.nonzero(eq)
return indices[:, 1]
这个选项的时间复杂度为O(n*m),其中n是a的大小,m是b的大小,O(n+m)是张量的内存。