如果数组元素出现多次,则筛选2D numpy数组



我想删除在2D数组中共享元素的行。例如:

array = [0 1]
        [2 3]
        [4 0]
        [0 4]
filtered_array = [2 3]

编辑:列位置与无关

以下是使用NumPy broadcasting-的矢量化方法

def filter_rows(arr):
    # Detect matches along same columns for both cols
    samecol_mask1 = arr[:,None,0] == arr[:,0]
    samecol_mask2 = arr[:,None,1] == arr[:,1]
    samecol_mask = np.triu(samecol_mask1 | samecol_mask2,1)
    # Detect matches across the two cols
    diffcol_mask = arr[:,None,0] == arr[:,1]
    # Get the combined matching mask
    mask = samecol_mask | diffcol_mask
    # Get the indices of the mask which gives us the row IDs that have matches
    # across either same or different columns. Delete those rows for output. 
    dup_rowidx = np.unique(np.argwhere(mask))
    return np.delete(arr,dup_rowidx,axis=0)

展示各种场景的样本运行

案例#1:相同和不同列之间的多个匹配

In [313]: arr
Out[313]: 
array([[0, 1],
       [2, 3],
       [4, 0],
       [0, 4]])
In [314]: filter_rows(arr)
Out[314]: array([[2, 3]])

案例#2:沿相同列匹配

In [319]: arr
Out[319]: 
array([[ 0,  1],
       [ 2,  3],
       [ 8, 10],
       [ 0,  4]])
In [320]: filter_rows(arr)
Out[320]: 
array([[ 2,  3],
       [ 8, 10]])

案例#3:沿不同列匹配

In [325]: arr
Out[325]: 
array([[ 0,  1],
       [ 2,  3],
       [ 8, 10],
       [ 7,  0]])
In [326]: filter_rows(arr)
Out[326]: 
array([[ 2,  3],
       [ 8, 10]])

案例4:在同一行中匹配

In [331]: arr
Out[331]: 
array([[ 0,  1],
       [ 3,  3],
       [ 8, 10],
       [ 7,  0]])
In [332]: filter_rows(arr)
Out[332]: array([[ 8, 10]])

只是@Divakar令人印象深刻的解决方案的替代方案。这种方法无论如何都更糟糕(尤其是效率),但对于非愚蠢的大师来说可能更容易理解。

import numpy as np
def filter_(x):
    unique = np.unique(x) # 1
    unique_mapper = [np.where(x == z)[0] for z in unique] # 2
    filtered_unique_mapper = list(map(lambda x: x if len(x) > 1 else [], unique_mapper)) # 3
    all = np.concatenate(filtered_unique_mapper) # 4
    to_delete = np.unique(all) # 5
    return np.delete(x, all, axis=0)
# 1 get global unique values
# 2 for each unique value: get all rows with this value
#   -> multiple entries for one unique value: row's collide!
# 3 remove entries from above, if only <= 1 rows hold that unique value
# 4 collect all rows, which collided somehow
# 5 remove multiple entries from above

相关内容

  • 没有找到相关文章

最新更新