我有一个3D数组,由每个带内的几个数字组成。是否有函数返回数组满足MULTIPLE条件的索引位置?
我尝试了以下方法:
index_pos = numpy.where(
array[:,:,0]==10 and array[:,:,1]==15 and array[:,:,2]==30)
它返回错误:
ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()
您实际上有一种特殊情况,在这种情况下,执行以下操作会更简单、更高效:
创建数据:
>>> arr
array([[[ 6, 9, 4],
[ 5, 2, 1],
[10, 15, 30]],
[[ 9, 0, 1],
[ 4, 6, 4],
[ 8, 3, 9]],
[[ 6, 7, 4],
[ 0, 1, 6],
[ 4, 0, 1]]])
预期值:
>>> index_pos = np.where((arr[:,:,0]==10) & (arr[:,:,1]==15) & (arr[:,:,2]==30))
>>> index_pos
(array([0]), array([2]))
使用广播同时做到这一点:
>>> arr == np.array([10,15,30])
array([[[False, False, False],
[False, False, False],
[ True, True, True]],
[[False, False, False],
[False, False, False],
[False, False, False]],
[[False, False, False],
[False, False, False],
[False, False, False]]], dtype=bool)
>>> np.where( np.all(arr == np.array([10,15,30]), axis=-1) )
(array([0]), array([2]))
如果你想要的索引不是连续的,你可以这样做:
ind_vals = np.array([0,2])
where_mask = (arr[:,:,ind_vals] == values)
尽可能进行广播。
在@Jamie评论的刺激下,一些有趣的事情需要考虑:
arr = np.random.randint(0,100,(5000,5000,3))
%timeit np.all(arr == np.array([10,15,30]), axis=-1)
1 loops, best of 3: 614 ms per loop
%timeit ((arr[:,:,0]==10) & (arr[:,:,1]==15) & (arr[:,:,2]==30))
1 loops, best of 3: 217 ms per loop
%timeit tmp = (arr == np.array([10,15,30])); (tmp[:,:,0] & tmp[:,:,1] & tmp[:,:,2])
1 loops, best of 3: 368 ms per loop
问题变成了,为什么会这样?:
首次检验:
%timeit (arr[:,:,0]==10)
10 loops, best of 3: 51.2 ms per loop
%timeit (arr == np.array([10,15,30]))
1 loops, best of 3: 300 ms per loop
可以预期CCD_ 1在更坏的情况下将是CCD_ 2的速度的1/3。有人知道为什么不是这样吗?
然后,当组合最终轴时,有许多方法可以实现这一点。
tmp = (arr == np.array([10,15,30]))
method1 = np.all(tmp,axis=-1)
method2 = (tmp[:,:,0] & tmp[:,:,1] & tmp[:,:,2])
method3 = np.einsum('ij,ij,ij->ij',tmp[:,:,0] , tmp[:,:,1] , tmp[:,:,2])
np.allclose(method1,method2)
True
np.allclose(method1,method3)
True
%timeit np.all(tmp,axis=-1)
1 loops, best of 3: 318 ms per loop
%timeit (tmp[:,:,0] & tmp[:,:,1] & tmp[:,:,2])
10 loops, best of 3: 68.2 ms per loop
%timeit np.einsum('ij,ij,ij->ij',tmp[:,:,0] , tmp[:,:,1] , tmp[:,:,2])
10 loops, best of 3: 38 ms per loop
einsum加速在其他地方有很好的定义,但对我来说,all
和连续的&
之间有这样的差异似乎很奇怪。
and
运算符在这种情况下不起作用。
index_pos = numpy.where(array[:,:,0]==10 and array[:,:,1]==15 and array[:,:,2]==30)
试试看:
index_pos = numpy.where((array[:,:,0]==10) & (array[:,:,1]==15) & (array[:,:,2]==30))
问题是使用了原生Python and
关键字,它在数组上的行为与您想要的不同。
相反,请尝试使用numpy.logical_and
函数。
cond1 = np.logical_and(array[:,:,0]==10, array[:,:,1]==15)
cond2 = np.logical_and(cond1, array[:,:,2]==30)
index_pos = numpy.where(cond2)
您甚至可以创建自己的logical_and
版本,该版本接受任意数量的条件:
def my_logical_and(*args):
return reduce(np.logical_and, args)
condition_locs_and_vals = [(0, 10), (1, 15), (2, 30)]
conditions = [array[:,:,x] == y for x,y in conditition_locs_and_vals]
my_logical_and(*conditions)
使用逐位和(&
(是可行的,但只是巧合。按位和用于比较位或arr == np.array([10,15,30])
0类型。使用它来比较数值数组的真值是不可靠的(例如,如果您突然需要对条目求值为True
的位置进行索引,而不是实际首先转换为bool
数组(。确实应该使用logical_and
而不是&
(即使它带有速度惩罚(。
此外,用&
将任意条件列表链接在一起,无论是阅读还是打字都会很痛苦。为了代码的可重用性,这样以后的程序员就不必围绕&
运算符的一堆附属子句进行更改,最好将各个条件单独存储,然后使用类似上面的函数来组合它们。