NumPy检查二维数组是否为二维数组的子集



我想检查数组b是否是数组a的子集。我的意思是,我想检查b的所有元素是否都在a中找到。

这是我的代码:

import numpy as np
a = np.array([[1,7,9],[8,3,12],[101,-74,0.5]])
b = np.array([[1,9],[8,12],[101,0.5]])
print a
print b

这是输出

阵列

[[   1.     7.     9. ]
 [   8.     3.    12. ]
 [ 101.   -74.     0.5]]

阵列b

[[   1.     9. ]
 [   8.    12. ]
 [ 101.     0.5]]

有没有办法检查b是否是a的子集?

编辑:附加信息:

根据下面的评论,我应该澄清一下,我需要知道数组b是否是数组a的子集——如果子集中甚至缺少一个元素,那么我正在寻找一种方法来检查这一点。我不需要指示元素在子集中的哪里缺失,只需要知道它缺失了。如果可以提供关于缺失元素的额外信息,那么这将是一个奖励,但这不是硬性要求。很抱歉没有早点解决这个问题。

我将问题表述为子集的理由是,如果一个数组是另一个数组的子集,那么这对我来说意味着子集数组的所有值都存在于较大的数组中。

我想你想要numpy.in1d,类似这样的东西:

import numpy as np
a = np.array([[1,7,9],[8,3,12],[101,-74,0.5]])
b = np.array([[1,9],[8,12],[101,0.5]])
np.in1d(b.ravel(), a.ravel()).all()

这应该有效:

set(np.unique(b)).issubset(set(np.unique(a)))

EDIT:上面的代码返回TrueFalse,而不是布尔值的列向量。从@Eelco Hoogendoorn对您的问题的评论中,我知道您实际上有兴趣检查b是否是a的相应的子集,对吧?假设这是正确的问题描述,下面的一行应该可以工作:

np.array([[set(bi).issubset(set(ai))] for ai, bi in zip(map(tuple, a), map(tuple, b))])

上面的代码简单易读,不需要第三方依赖关系。无可否认,这是一个快速而肮脏的解决方案,因为正如@Bi-Rico正确指出的那样,这种方法可能非常低效。如果需要处理大型数组,则应坚持使用矢量化算法。

如果您想比较,一种方法是首先对它们进行分组:

a = np.array([[1,7,9],[8,3,12],[101,-74,0.5]])
b = np.array([[1,9],[8,12],[101,0.5]])
c = np.array([[1,9],[8,12],[101,-74.]])
def bycols(arr):
    tr=arr.T.copy()
    type=np.dtype((np.void,tr.strides[0]))
    return tr.view(type).squeeze()
A,B,C=[bycols(x) for x in (a,b,c)]    

那么A、B、C只是表示列的字节数组:

In [5]: [x.shape for x in (A,B,C)]
Out[5]: [(3,), (2,), (2,)]

你现在可以用np.in1d测试归属:

In [6]: np.in1d(C,A)
Out[6]: array([ True, False], dtype=bool)
In [7]: np.in1d(B,A)
Out[7]: array([ True,  True], dtype=bool)

但是:

In [8]: np.in1d(c,a)
Out[8]: array([ True,  True,  True,  True,  True,  True], dtype=bool)

因为np1d适用于扁平阵列。

如果我正确阅读了你的问题(测试a和b中对应的每一行,如果b中的行是a中行的子集),这应该能有效而正确地完成:

import numpy_indexed as npi
rowsa = np.indices(a.shape)[0]
rowsb = np.indices(b.shape)[0]
# test for each value-rowidx pair in b if it is contained in a
c = npi.contains((a.flatten(), rowsa.flatten()), (b.flatten(), rowsb.flatten()))
# check that all elements on a row are contained
row_is_subset = c.reshape(b.shape).all(axis=1)

您需要安装numpy_indexed包(免责声明:我是它的作者)

最新更新