仅保留位置0处具有一个唯一值的子数组



从Numpy nd数组开始:

>>> arr
[
[
[10, 4, 5, 6, 7],
[11, 1, 2, 3, 4],
[11, 5, 6, 7, 8]
],
[
[12, 4, 5, 6, 7],
[12, 1, 2, 3, 4],
[12, 5, 6, 7, 8]
],
[
[15, 4, 5, 6, 7],
[15, 1, 2, 3, 4],
[15, 5, 6, 7, 8]
],
[
[13, 4, 5, 6, 7],
[13, 1, 2, 3, 4],
[14, 5, 6, 7, 8]
],
[
[10, 4, 5, 6, 7],
[11, 1, 2, 3, 4],
[12, 5, 6, 7, 8]
]
]

我只想保留在位置0只有一个唯一值的3个子阵列的序列,以获得以下内容:

>>> new_arr
[
[
[12, 4, 5, 6, 7],
[12, 1, 2, 3, 4],
[12, 5, 6, 7, 8]
],
[
[15, 4, 5, 6, 7],
[15, 1, 2, 3, 4],
[15, 5, 6, 7, 8]
]
]

从初始阵列中,arr[0]arr[3]arr[4]被丢弃,因为它们在位置0都具有一个以上的唯一值(分别为[10, 11][13, 14][10, 11, 12](。

我试着摆弄numpy.unique(),但只能在所有子数组中的0位置获得全局唯一值,这不是这里所需要的。

--编辑

以下内容似乎让我更接近解决方案:

>>> np.unique(arr[0, :, 0])
array([10, 11])

但我不知道如何在不使用Python循环的情况下,为arr的每个子数组获得比这更高的一个级别,并为此设置一个条件。

我在没有任何转座的情况下完成了这项工作。

arr = np.array(arr)
arr[np.all(arr[:, :, 0] == arr[:, :1, 0], axis=1)]

我很想看看这些方法的比较情况,所以我在这里使用(4000000, 4, 4)的大型数据集对答案进行了基准测试。

结果

--------------------------------------------------------------------------------------- benchmark: 4 tests ---------------------------------------------------------------------------------------
Name (time in ms)            Min                   Max                  Mean             StdDev                Median                IQR            Outliers     OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_np_arr_T           128.3483 (1.0)        130.5462 (1.0)        129.0869 (1.0)       0.9536 (1.01)       128.5447 (1.0)       1.5660 (1.83)          2;0  7.7467 (1.0)           8           1
test_np_arr             128.5017 (1.00)       131.2399 (1.01)       129.2841 (1.00)      0.9414 (1.0)        128.9724 (1.00)      0.8553 (1.0)           1;1  7.7349 (1.00)          7           1
test_pure_py_set      2,840.2911 (22.13)    2,849.0413 (21.82)    2,844.4716 (22.04)     3.8494 (4.09)     2,846.1608 (22.14)     6.4168 (7.50)          3;0  0.3516 (0.05)          5           1
test_pure_py          3,688.4772 (28.74)    3,750.0933 (28.73)    3,717.3411 (28.80)    24.7294 (26.27)    3,707.3502 (28.84)    37.1902 (43.48)         2;0  0.2690 (0.03)          5           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

这些基准测试使用pytest-benchmark,所以我会制作一个venv来运行这个:

python3 -m venv venv
. ./venv/bin/activate
pip install numpy pytest pytest-benchmark

运行测试:

pytest test_runs.py

test_runs.py

import numpy as np
# No guarantee this will produce sub-arrays with shared first index
ARR = np.random.randint(low=0, high=10, size=(4_000_000, 4, 4)).tolist()
# ARR = [
#     [[10, 4, 5, 6, 7], [11, 1, 2, 3, 4], [11, 5, 6, 7, 8]],
#     [[12, 4, 5, 6, 7], [12, 1, 2, 3, 4], [12, 5, 6, 7, 8]],
#     [[15, 4, 5, 6, 7], [15, 1, 2, 3, 4], [15, 5, 6, 7, 8]],
#     [[13, 4, 5, 6, 7], [13, 1, 2, 3, 4], [14, 5, 6, 7, 8]],
#     [[10, 4, 5, 6, 7], [11, 1, 2, 3, 4], [12, 5, 6, 7, 8]],
# ]
def pure_py(arr):
new_array = []
for i, v in enumerate(arr):
first_elems = [x[0] for x in v]
if all(elem == first_elems[0] for elem in first_elems):
new_array.append(arr[i])
return new_array
def pure_py_set(arr):
new_array = []
for sub_arr in arr:
if len(set(x[0] for x in sub_arr)) == 1:
new_array.append(sub_arr)
return new_array
def np_arr(arr):
return arr[np.all(arr[:, :, 0] == arr[:, :1, 0], axis=1)]
def np_arr_T(arr):
return arr[(arr[:, :, 0].T == arr[:, 0, 0]).T.all(axis=1)]
def np_not_arr(arr):
arr = np.array(arr)
return arr[np.all(arr[:, :, 0] == arr[:, :1, 0], axis=1)]
RES = np_not_arr(ARR).tolist()
def test_pure_py(benchmark):
res = benchmark(pure_py, ARR)
assert res == RES
def test_pure_py_set(benchmark):
res = benchmark(pure_py_set, ARR)
assert res == RES
def test_np_arr(benchmark):
ARR_ = np.array(ARR)
res = benchmark(np_arr, ARR_)
assert res.tolist() == RES
def test_np_arr_T(benchmark):
ARR_ = np.array(ARR)
res = benchmark(np_arr_T, ARR_)
assert res.tolist() == RES

受试图以编辑的形式回答问题(我拒绝了,因为这应该是一个答案(的启发,以下是行之有效的方法:

>>> arr[(arr[:,:,0].T == arr[:,0,0]).T.all(axis=1)]
[
[
[12, 4, 5, 6, 7],
[12, 1, 2, 3, 4],
[12, 5, 6, 7, 8]
],
[
[15, 4, 5, 6, 7],
[15, 1, 2, 3, 4],
[15, 5, 6, 7, 8]
]
]

诀窍是将结果转换为:

# all 0-th positions of each subarray
arr[:,:,0].T
# the first 0-th position of each subarray 
arr[:,0,0]
# whether each 0-th position equals the first one
(arr[:,:,0].T == arr[:,0,0]).T
# keep only the sub-array where the above is true for all positions
(arr[:,:,0].T == arr[:,0,0]).T.all(axis=1)
# lastly, apply this indexing to the initial array
arr[(arr[:,:,0].T == arr[:,0,0]).T.all(axis=1)]

好的,我比较了这个问题的两种解决方案。有numpy(由@rchome编写的脚本(而没有它-纯python

new_array = []
for i, v in enumerate(arr):
first_elems = [x[0] for x in v]
if all(elem == first_elems[0] for elem in first_elems):
new_array.append(arr[i])

该代码执行时间=(+-0:00.000015(

arr = np.array(arr)
new_array = arr[np.all(arr[:, :, 0] == arr[:, :1, 0], axis=1)]

该代码执行时间=(+-0:00.000060(

因此,使用numpy大约需要4倍的时间。但我们必须记住,这个数组非常小。也许使用更大的数组numpy会工作得更快:(

--编辑--我已经将数组放大了大约10倍,以下是我的结果:

  • python:0:00:00.000205
  • 编号:0:00:00.002710

所以。也许对于这个任务来说,使用numpy是多余的。

最新更新