首先,对模糊的标题表示歉意,我想不出这个问题的合适名称。
我有 3 个以下格式的 numpy 数组:
N = ([[13, 14, 15], [2, 5, 7], [4, 6, 8] ...几十万元素长
e1 = [1, 0, 0]
e2 = [0, 1, 0]
这个想法是创建第四个数组 'v',它应该与 'N' 具有相同的维度,但将根据 if 语句给出值。这是我目前拥有的应该更好地解释这个问题:
v = np.zeros([len(N), 3])
for i in range(0, len(N)):
if((N*e1)[i,0] != 0):
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
这段代码执行我需要它执行的操作,但执行时间比预期的要长(> 5 分钟)。我可以使用任何形式的列表理解或类似概念来提高代码的效率?
numpy.where
替换if-else并用广播对过程进行矢量化,这里有一个带有numpy.where
的选项:
import numpy as np
np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
这里有一些基准:
1) 数据设置:
N = np.array([np.random.randint(0,10,3) for i in range(1000)])
N
#array([[3, 5, 0],
# [5, 0, 8],
# [4, 6, 0],
# ...,
# [9, 4, 2],
# [6, 9, 3],
# [2, 9, 2]])
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])
2)时间:
def forloop():
v = np.zeros([len(N), 3]);
for i in range(0, len(N)):
if((N*e1)[i,0] != 0):
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
return v
def forloop2():
v = np.zeros([len(N), 3])
# Only calculate this one time.
my_product = N*e1
for i in range(0, len(N)):
if my_product[i,0] != 0:
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
return v
%timeit forloop()
10 loops, best of 3: 25.5 ms per loop
%timeit forloop2()
100 loops, best of 3: 12.7 ms per loop
%timeit np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
10000 loops, best of 3: 71.9 µs per loop
3) 所有方法的结果检查:
v1 = forloop()
v2 = np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
v3 = forloop2()
(v3 == v1).all()
# True
(v1 == v2).all()
# True
我不确定你想做什么,但我知道为什么这个特定的代码对你来说这么慢。最严重的罪犯是(N*e1)
.这是一个简单的计算,它使用 numpy 运行得非常快,但你在循环中执行它,len(N)
次!
通过将代码拉出循环,我可以在不到 15 秒的时间内在我的机器上用 N == 1000000
执行您的代码。下面是示例。
v = np.zeros([len(N), 3])
# Only calculate this one time.
my_product = N*e1
for i in range(0, len(N)):
if my_product[i,0] != 0):
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
另一个答案演示了如何避免 for 循环和 if 语句,以降低可读性代码为代价获得大量额外的速度。