如何逐个元素查找哪个numpy数组包含最大值



给定一个numpy数组列表,每个数组都有相同的维度,我如何逐个元素找到哪个数组包含最大值?

例如

import numpy as np
def find_index_where_max_occurs(my_list):
    # d = ...  something goes here ...
    return d
a=np.array([1,1,3,1])
b=np.array([3,1,1,1])
c=np.array([1,3,1,1])
my_list=[a,b,c]
array_of_indices_where_max_occurs = find_index_where_max_occurs(my_list)
# This is what I want:
# >>> print array_of_indices_where_max_occurs
# array([1,2,0,0])
# i.e. for the first element, the maximum value occurs in array b which is at index 1 in my_list.

任何帮助都将不胜感激。。。谢谢

如果您想要一个数组,另一个选项是:

>>> np.array((a, b, c)).argmax(axis=0)
array([1, 2, 0, 0])

因此:

def f(my_list):
    return np.array(my_list).argmax(axis=0)

这也适用于多维数组。

为了好玩,我意识到@Lev的原始答案比他的广义编辑快,所以这是广义堆叠版本,比np.asarray版本快得多,但不是很优雅。

np.concatenate((a[None,...], b[None,...], c[None,...]), axis=0).argmax(0)

即:

def bystack(arrs):
    return np.concatenate([arr[None,...] for arr in arrs], axis=0).argmax(0)

一些解释:

我为每个数组添加了一个新的轴:arr[None,...]等效于arr[np.newaxis,...],这与arr[np.newaxis,:,:,:]相同,其中...扩展为适当的数字维度。这是因为np.concatenate将沿着新维度堆叠,该维度是0,因为None在前面。

例如:

In [286]: a
Out[286]: 
array([[0, 1],
       [2, 3]])
In [287]: b
Out[287]: 
array([[10, 11],
       [12, 13]])
In [288]: np.concatenate((a[None,...],b[None,...]),axis=0)
Out[288]: 
array([[[ 0,  1],
        [ 2,  3]],
       [[10, 11],
        [12, 13]]])

如果它有助于理解,这也会起作用:

np.concatenate((a[...,None], b[...,None], c[...,None]), axis=a.ndim).argmax(a.ndim)

其中新轴现在添加在末尾,所以我们必须沿着最后一个轴堆叠并最大化,该轴将是a.ndim。对于abc是2d,我们可以这样做:

np.concatenate((a[:,:,None], b[:,:,None], c[:,:,None]), axis=2).argmax(2)

这相当于我在上面的评论中提到的dstack(如果dstack在数组中不存在,它会添加第三个轴来堆叠(。

测试:

N = 10
M = 2
a = np.random.random((N,)*M)
b = np.random.random((N,)*M)
c = np.random.random((N,)*M)
def bystack(arrs):
    return np.concatenate([arr[None,...] for arr in arrs], axis=0).argmax(0)
def byarray(arrs):
    return np.array(arrs).argmax(axis=0)
def byasarray(arrs):
    return np.asarray(arrs).argmax(axis=0)
def bylist(arrs):
    assert arrs[0].ndim == 1, "ndim must be 1"
    return [np.argmax(x) for x in zip(*arrs)]

In [240]: timeit bystack((a,b,c))
100000 loops, best of 3: 18.3 us per loop
In [241]: timeit byarray((a,b,c))
10000 loops, best of 3: 89.7 us per loop
In [242]: timeit byasarray((a,b,c))
10000 loops, best of 3: 90.0 us per loop
In [259]: timeit bylist((a,b,c))
1000 loops, best of 3: 267 us per loop
[np.argmax(x) for x in zip(*my_list)]

好吧,这只是一个列表,但如果你想的话,你知道如何使它成为一个数组。:(

为了解释它的作用:zip(*my_list)等价于zip(a,b,c),它为您提供了一个循环的生成器。循环中的每个步骤都会给您一个类似(a[i], b[i], c[i])的元组,其中i是循环中的步骤。然后,np.argmax为具有最大值的元素提供该元组的索引。

最新更新