How to filter a numpy array by a condition on a property of its elements, not just on their values?



If I want to filter a numpy array by a condition on its values, I can do:

import numpy as np
arr = np.array([1, 2, 3, 4])
filtered = arr[arr > 2]  # [3, 4]

What if my elements have some property that I want to filter on? Like this:

arr = np.array([1], [2, 2], [3, 3, 3], [4, 4, 4, 4])
filtered = arr[len(arr) > 2]
# this does not output the desired [[3, 3, 3], [4, 4, 4, 4]], but rather [arr]
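Why the attempt returns the whole array wrapped in an extra dimension: len(arr) > 2 is evaluated once, on the outer array, and produces a single Python bool. NumPy treats a scalar boolean index as adding a new axis of length 1 (for True) or 0 (for False). A minimal sketch, assuming the 1-D object array that the answers below construct:

```python
import numpy as np

# A 1-D object array of ragged lists (dtype=object is required for ragged input)
arr = np.array([[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]], dtype=object)

cond = len(arr) > 2   # a single bool for the whole array: 4 > 2 -> True
wrapped = arr[cond]   # a scalar boolean index adds a leading axis of length 1
print(cond, wrapped.shape)  # True (1, 4)

# What the filter actually needs is one boolean per element, i.e. a mask of shape (4,)
mask = np.array([len(x) > 2 for x in arr])
print(mask)  # [False False  True  True]
```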

Here you go:

In [7]: import numpy as np
In [8]: arr = np.array([[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]], dtype=object)
In [9]: check = np.vectorize(lambda element: len(element) > 2)
In [10]: arr[check(arr)]
Out[10]: array([list([3, 3, 3]), list([4, 4, 4, 4])], dtype=object)
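By default np.vectorize determines its output dtype by calling the function on the first element, then calls it again on every element. Passing otypes=[bool] fixes the mask's dtype up front and avoids that trial call. A small sketch of the same filter with that parameter; the name longer_than_2 is only illustrative:

```python
import numpy as np

arr = np.array([[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]], dtype=object)

# otypes declares the output dtype, so vectorize needs no trial call to infer it
longer_than_2 = np.vectorize(lambda element: len(element) > 2, otypes=[bool])

mask = longer_than_2(arr)  # one bool per element
filtered = arr[mask]       # still a 1-D object array
print(filtered.tolist())   # [[3, 3, 3], [4, 4, 4, 4]]
```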

You can vectorize the len function and apply it to any array you need:

In [25]: vv = np.vectorize(len)
In [26]: vv(arr)
Out[26]: array([1, 2, 3, 4])
In [28]: arr[vv(arr)>2]
Out[28]: array([list([3, 3, 3]), list([4, 4, 4, 4])], dtype=object)

Or:

In [29]: vv = np.vectorize(lambda x: len(x)>2)
In [30]: arr[vv(arr)]
Out[30]: array([list([3, 3, 3]), list([4, 4, 4, 4])], dtype=object)

And a benchmark (with the current array):

In [31]: %timeit arr[vv(arr)]
31.6 µs ± 385 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [32]: vv = np.vectorize(len)
In [33]: %timeit arr[vv(arr)>2]
35 µs ± 578 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
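If you want the lengths themselves (for example, to filter on several thresholds), they can also be collected with np.fromiter over a plain generator instead of np.vectorize(len). This is a swapped-in technique, and the thresholds below are purely illustrative:

```python
import numpy as np

arr = np.array([[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]], dtype=object)

# Compute every element's length once, into an ordinary integer array
lengths = np.fromiter((len(x) for x in arr), dtype=int, count=len(arr))
print(lengths)  # [1 2 3 4]

# The integer array supports normal vectorized comparisons and combinations
print(arr[lengths > 2].tolist())                      # [[3, 3, 3], [4, 4, 4, 4]]
print(arr[(lengths >= 2) & (lengths <= 3)].tolist())  # [[2, 2], [3, 3, 3]]
```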

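Note: this answers your specific case only. One could argue that you should change your data structure, or something else in your code, instead. Those suggestions are all worth considering, but keep in mind that rethinking how the problem is framed sometimes makes it much simpler.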

Your example produces an error:

In [139]: arr = np.array([1], [2, 2], [3, 3, 3], [4, 4, 4, 4])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-139-f49650a95073> in <module>
----> 1 arr = np.array([1], [2, 2], [3, 3, 3], [4, 4, 4, 4])
TypeError: array() takes from 1 to 2 positional arguments but 4 were given

Did you intend to create an object dtype array?

In [140]: arr = np.array([[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]])
/usr/local/bin/ipython3:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
#!/usr/bin/python3
In [141]: arr
Out[141]: 
array([list([1]), list([2, 2]), list([3, 3, 3]), list([4, 4, 4, 4])],
dtype=object)

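(This deprecation warning is produced by NumPy 1.19dev. In NumPy 1.24 and later, creating a ragged array without dtype=object raises a ValueError instead.)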
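One caveat with dtype=object: if the sublists happen to have equal lengths, np.array(..., dtype=object) builds a 2-D object array rather than a 1-D array of lists. A common workaround, sketched here with a hypothetical helper to_object_array, is to allocate a 1-D object array and fill it one slot at a time:

```python
import numpy as np

def to_object_array(items):
    """Hypothetical helper: always return a 1-D object array holding each item as-is."""
    out = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        out[i] = item  # assigning to a single slot stores the list object itself
    return out

ragged = to_object_array([[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]])
even = to_object_array([[1, 2], [3, 4]])

print(ragged.shape, even.shape)                         # (4,) (2,)
print(np.array([[1, 2], [3, 4]], dtype=object).shape)   # (2, 2) by contrast
```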

In most cases an object dtype array behaves more like a list than like a numeric array, so we can use a list comprehension to find the longer lists:

In [143]: [a for a in arr if len(a)>2]
Out[143]: [[3, 3, 3], [4, 4, 4, 4]]
In [146]: timeit [a for a in arr if len(a)>2]
1.36 µs ± 3.52 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
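If the result should stay a NumPy object array (for example, so it can keep being indexed with masks), the list comprehension can build the boolean mask instead of the filtered list. A minimal sketch of that variation:

```python
import numpy as np

arr = np.array([[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]], dtype=object)

# Plain Python iteration produces one bool per element
mask = np.array([len(a) > 2 for a in arr], dtype=bool)

filtered = arr[mask]                   # 1-D object array, like the np.vectorize results
print(filtered.dtype, filtered.shape)  # object (2,)
```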

===

Others have suggested np.vectorize. It is slower:

In [144]: np.vectorize(lambda a: len(a)>2)(arr)
Out[144]: array([False, False,  True,  True])
In [145]: arr[np.vectorize(lambda a: len(a)>2)(arr)]
Out[145]: array([list([3, 3, 3]), list([4, 4, 4, 4])], dtype=object)
In [147]: timeit arr[np.vectorize(lambda a: len(a)>2)(arr)]
32.3 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
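A related built-in is np.frompyfunc, which wraps a Python function as a ufunc whose outputs always have dtype=object; the result therefore has to be cast with .astype(bool) before it can serve as a mask. It still calls the Python function once per element, so it is not a way around per-element Python calls. A sketch:

```python
import numpy as np

arr = np.array([[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]], dtype=object)

# frompyfunc(func, nin, nout) builds a ufunc that returns object arrays
longer_than_2 = np.frompyfunc(lambda a: len(a) > 2, 1, 1)

raw = longer_than_2(arr)   # object array of Python bools
mask = raw.astype(bool)    # convert to a real boolean mask usable for indexing
print(arr[mask].tolist())  # [[3, 3, 3], [4, 4, 4, 4]]
```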

The list comprehension is faster still when it iterates over a list rather than an array:

In [148]: %%timeit alist=arr.tolist()
...: [a for a in alist if len(a)>2]
626 ns ± 2.41 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

For this kind of element-by-element Python work, lists can be faster than numpy arrays.
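Outside IPython, the same comparison can be rerun with the standard-library timeit module. A sketch; absolute numbers depend on the machine and NumPy version, so none are claimed here:

```python
import timeit

import numpy as np

arr = np.array([[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]], dtype=object)
alist = arr.tolist()
longer = np.vectorize(lambda a: len(a) > 2)

candidates = {
    "comprehension over array": lambda: [a for a in arr if len(a) > 2],
    "comprehension over list": lambda: [a for a in alist if len(a) > 2],
    "np.vectorize mask": lambda: arr[longer(arr)],
}

results = {}
for name, func in candidates.items():
    # Best of 5 repeats, each running the callable 1000 times; store per-call seconds
    best = min(timeit.repeat(func, number=1_000, repeat=5)) / 1_000
    results[name] = best
    print(f"{name}: {best * 1e6:.2f} µs per call")
```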

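I only recommend np.vectorize when you need it to "broadcast" several arrays against a function that only accepts scalar arguments. It is not an efficient replacement for simple iteration.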
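For contrast, a sketch of a case where np.vectorize does earn its place: a function written for scalars (the label function below is purely illustrative) applied to two arrays that broadcast against each other:

```python
import numpy as np

def label(value, threshold):
    """Illustrative scalar-only function: a plain `if` cannot handle whole arrays."""
    if value > threshold:
        return "high"
    return "low"

# otypes=[object] keeps each result as a plain Python str
vlabel = np.vectorize(label, otypes=[object])

values = np.array([1, 5, 9])       # shape (3,)
thresholds = np.array([[2], [6]])  # shape (2, 1)

grid = vlabel(values, thresholds)  # broadcasts to shape (2, 3)
print(grid.shape)     # (2, 3)
print(grid.tolist())  # [['low', 'high', 'high'], ['low', 'low', 'high']]
```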
