是否可以在Python中对以下代码进行矢量化?当数组的大小变大时,它运行得很慢。
import numpy as np
# A, B, C are 3d arrays with shape (K, N, N).
# Entries in A, B, and C are in [0, 1].
# In the following, I use random values in B and C as an example.
K = 5
N = 10000
A = np.zeros((K, N, N))
B = np.random.normal(0, 1, (K, N, N))
C = np.random.normal(0, 1, (K, N, N))
for k in range(K):
for m in [x for x in range(K) if x != k]:
for i in range(N):
for j in range(N):
if A[m, i, j] not in [0, 1]:
if A[k, i, j] == 1:
A[m, i, j] = B[m ,i ,j]
if A[k ,i, j] == 0:
A[m, i, j] = C[m, i, j]
我无法确定一种向量化的方法,但我可以建议使用numba
包来减少计算时间。在这里,您可以使用nogil=True
参数导入njit
以加快您的代码。
from numba import njit
@njit(nogil=True)
def function():
for k in range(K):
for m in [x for x in range(K) if x != k]:
for i in range(N):
for j in range(N):
if A[k, i, j] == 1 and A[m, i, j] not in [0, 1]:
A[m, i, j] = B[m ,i ,j]
if A[k ,i, j] == 0 and A[m, i, j] not in [0, 1]:
A[m, i, j] = C[m, i, j]
%timeit function()
7.35 s ± 252 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
使用njit
和nogil
参数,我花了7秒来运行整个东西,但是没有njit
,我的代码运行了几个小时(现在仍然是)。Python有一个全局解释器锁(GIL)来确保它坚持单线程。通过释放GIL,您可以在多线程中执行代码。但是,在使用nogil=True
时,您必须警惕多线程编程的常见陷阱(一致性、同步、竞争条件等)。
您可以在这里查看有关Numba的文档。https://numba.pydata.org/numba-doc/dev/user/jit.html?highlight=nogil
我可以帮助部分矢量化,这应该会加快一些速度,但我不确定k与m的逻辑,所以没有尝试包括那部分。本质上,你创建了一个蒙版,你想在A
的第2和第3个维度上检查条件。然后在A
和B
或C
之间使用适当的掩码映射:
# A, B, C are 3d arrays with shape (K, N, N).
# Entries in A, B, and C are in [0, 1].
# In the following, I use random values in B and C as an example.
np.random.seed(10)
K = 5
N = 1000
A = np.zeros((K, N, N))
B = np.random.normal(0, 1, (K, N, N))
C = np.random.normal(0, 1, (K, N, N))
for k in range(K):
for m in [x for x in range(K) if x != k]:
#if A[m, i, j] not in [0, 1]:
mask_1 = A[k, :, :] == 1
mask_0 = A[k, :, :] == 0
A[m, mask_1] = B[m, mask_1]
A[m, mask_0] = C[m, mask_0]
我省略了A[m, i, j] not in [0, 1]
部分,因为这使得调试变得困难,因为什么都没有发生(A被初始化为全零)。如果你需要包含这样的附加逻辑,只需为它创建另一个掩码,并在每个掩码的逻辑中包含and
。
更新日期:7/6/22如果你想更新上面的代码来移除m
上的循环,那么你可以用k
的所有值初始化一个数组,并使用它来扩展mask
以包括所有3个维度,排除k
与m
匹配的每个值,如下所示:
np.random.seed(10)
K = 5
N = 1000
A_2 = np.zeros((K, N, N))
B = np.random.normal(0, 1, (K, N, N))
C = np.random.normal(0, 1, (K, N, N))
K_vals = np.array(range(K))
for k in range(K):
#for m in [x for x in range(K) if x != k]:
#if A[m, i, j] not in [0, 1]:
k_dim_2_skip = K_vals == k
mask_1 = np.tile(A_2[k, :, :] == 1, (K, 1, 1))
mask_1[k_dim_2_skip, :, :] = False
mask_0 = np.tile(A_2[k, :, :] == 0, (K, 1, 1))
mask_0[k_dim_2_skip, :, :] = False
A_2[mask_1] = B[mask_1]
A_2[mask_0] = C[mask_0]
将这些蒙版与您在下面的评论中添加的& np.logical_not...
代码一起使用,这样就可以了。请注意,矢量化越多,为掩码等操作的数组就越大,因此需要在内存消耗方面进行权衡。对于给定的问题,通常有一个平衡点来平衡运行时间和内存使用。