将3d numpy数组的赋值矢量化,该赋值以其他维度的关联值为条件



是否可以在Python中对以下代码进行矢量化?当数组的大小变大时,它运行得很慢。

import numpy as np
# A, B, C are 3d arrays with shape (K, N, N). 
# Entries in A, B, and C are in [0, 1]. 
# In the following, I use random values in B and C as an example.
K = 5
N = 10000
A = np.zeros((K, N, N))
B = np.random.normal(0, 1, (K, N, N))
C = np.random.normal(0, 1, (K, N, N))
for k in range(K):
for m in [x for x in range(K) if x != k]:
for i in range(N):
for j in range(N):
if A[m, i, j] not in [0, 1]:
if A[k, i, j] == 1:
A[m, i, j] = B[m ,i ,j]
if A[k ,i, j] == 0:
A[m, i, j] = C[m, i, j]

我无法确定一种向量化的方法,但我可以建议使用numba包来减少计算时间。在这里,您可以使用nogil=True参数导入njit以加快您的代码。

from numba import njit
@njit(nogil=True)
def function():
for k in range(K):
for m in [x for x in range(K) if x != k]:
for i in range(N):
for j in range(N):
if A[k, i, j] == 1 and A[m, i, j] not in [0, 1]:
A[m, i, j] = B[m ,i ,j]
if A[k ,i, j] == 0 and A[m, i, j] not in [0, 1]:
A[m, i, j] = C[m, i, j]
%timeit function()
7.35 s ± 252 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

使用njitnogil参数,我花了7秒来运行整个东西,但是没有njit,我的代码运行了几个小时(现在仍然是)。Python有一个全局解释器锁(GIL)来确保它坚持单线程。通过释放GIL,您可以在多线程中执行代码。但是,在使用nogil=True时,您必须警惕多线程编程的常见陷阱(一致性、同步、竞争条件等)。

您可以在这里查看有关Numba的文档。https://numba.pydata.org/numba-doc/dev/user/jit.html?highlight=nogil

我可以帮助部分矢量化,这应该会加快一些速度,但我不确定k与m的逻辑,所以没有尝试包括那部分。本质上,你创建了一个蒙版,你想在A的第2和第3个维度上检查条件。然后在ABC之间使用适当的掩码映射:

# A, B, C are 3d arrays with shape (K, N, N). 
# Entries in A, B, and C are in [0, 1]. 
# In the following, I use random values in B and C as an example.
np.random.seed(10)
K = 5
N = 1000
A = np.zeros((K, N, N))
B = np.random.normal(0, 1, (K, N, N))
C = np.random.normal(0, 1, (K, N, N))
for k in range(K):
for m in [x for x in range(K) if x != k]:
#if A[m, i, j] not in [0, 1]:
mask_1 = A[k, :, :] == 1
mask_0 = A[k, :, :] == 0
A[m, mask_1] = B[m, mask_1]
A[m, mask_0] = C[m, mask_0]

我省略了A[m, i, j] not in [0, 1]部分,因为这使得调试变得困难,因为什么都没有发生(A被初始化为全零)。如果你需要包含这样的附加逻辑,只需为它创建另一个掩码,并在每个掩码的逻辑中包含and

更新日期:7/6/22如果你想更新上面的代码来移除m上的循环,那么你可以用k的所有值初始化一个数组,并使用它来扩展mask以包括所有3个维度,排除km匹配的每个值,如下所示:

np.random.seed(10)
K = 5
N = 1000
A_2 = np.zeros((K, N, N))
B = np.random.normal(0, 1, (K, N, N))
C = np.random.normal(0, 1, (K, N, N))
K_vals = np.array(range(K))
for k in range(K):
#for m in [x for x in range(K) if x != k]:
#if A[m, i, j] not in [0, 1]:
k_dim_2_skip = K_vals == k
mask_1 = np.tile(A_2[k, :, :] == 1, (K, 1, 1))
mask_1[k_dim_2_skip, :, :] = False
mask_0 = np.tile(A_2[k, :, :] == 0, (K, 1, 1))
mask_0[k_dim_2_skip, :, :] = False
A_2[mask_1] = B[mask_1]
A_2[mask_0] = C[mask_0]

将这些蒙版与您在下面的评论中添加的& np.logical_not...代码一起使用,这样就可以了。请注意,矢量化越多,为掩码等操作的数组就越大,因此需要在内存消耗方面进行权衡。对于给定的问题,通常有一个平衡点来平衡运行时间和内存使用。

相关内容

  • 没有找到相关文章

最新更新