Python - 如何加快从另一个 numpy 数组计算创建 numpy 数组的 for 循环



首先,对模糊的标题表示歉意,我想不出这个问题的合适名称。

我有 3 个以下格式的 numpy 数组:

N = ([[13, 14, 15], [2, 5, 7], [4, 6, 8] ...几十万元素长

e1 = [1, 0, 0]

e2 = [0, 1, 0]

这个想法是创建第四个数组 'v',它应该与 'N' 具有相同的维度,但将根据 if 语句给出值。这是我目前拥有的应该更好地解释这个问题:

v = np.zeros([len(N), 3])    
for i in range(0, len(N)):
    if((N*e1)[i,0] != 0):
        v[i] = np.cross(N[i],e1)
    else:
        v[i] = np.cross(N[i],e2)

这段代码执行我需要它执行的操作,但执行时间比预期的要长(> 5 分钟)。我可以使用任何形式的列表理解或类似概念来提高代码的效率?

您可以使用

numpy.where替换if-else并用广播对过程进行矢量化,这里有一个带有numpy.where的选项:

import numpy as np
np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))

这里有一些基准:

1) 数据设置

N = np.array([np.random.randint(0,10,3) for i in range(1000)])
N
#array([[3, 5, 0],
#       [5, 0, 8],
#       [4, 6, 0],
#       ..., 
#       [9, 4, 2],
#       [6, 9, 3],
#       [2, 9, 2]])
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])

2)时间

def forloop():
    v = np.zeros([len(N), 3]);    
​
    for i in range(0, len(N)):
        if((N*e1)[i,0] != 0):
            v[i] = np.cross(N[i],e1)
        else:
            v[i] = np.cross(N[i],e2)
    return v
def forloop2():
    v = np.zeros([len(N), 3])    
​
    # Only calculate this one time.
    my_product = N*e1
​
    for i in range(0, len(N)):
        if my_product[i,0] != 0:
            v[i] = np.cross(N[i],e1)
        else:
            v[i] = np.cross(N[i],e2)               
    return v
%timeit forloop()
10 loops, best of 3: 25.5 ms per loop
%timeit forloop2()
100 loops, best of 3: 12.7 ms per loop    
%timeit np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
10000 loops, best of 3: 71.9 µs per loop

3) 所有方法的结果检查

v1 = forloop()   
v2 = np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
v3 = forloop2()
(v3 == v1).all()
# True
(v1 == v2).all()
# True

我不确定你想做什么,但我知道为什么这个特定的代码对你来说这么慢。最严重的罪犯是(N*e1).这是一个简单的计算,它使用 numpy 运行得非常快,但你在循环中执行它,len(N)次!

通过将代码拉出循环,我可以在不到 15 秒的时间内在我的机器上用 N == 1000000 执行您的代码。下面是示例。

v = np.zeros([len(N), 3])    
# Only calculate this one time.
my_product = N*e1
for i in range(0, len(N)):
    if my_product[i,0] != 0):
        v[i] = np.cross(N[i],e1)
    else:
        v[i] = np.cross(N[i],e2)

另一个答案演示了如何避免 for 循环和 if 语句,以降低可读性代码为代价获得大量额外的速度。

相关内容

  • 没有找到相关文章

最新更新