在 numpy 中矢量化 for 循环以计算胶带重叠



我正在使用python创建一个应用程序来计算胶带重叠(对分配器进行建模,将产品应用于旋转的滚筒(。

我有一个运行正常的程序,但速度真的很慢。我正在寻找一种解决方案来优化用于填充 numpy 数组的for循环。有人可以帮我矢量化下面的代码吗?

import numpy as np
import matplotlib.pyplot as plt
# Some parameters
width = 264
bbddiam = 940
accuracy = 4 #2 points per pixel
drum = np.zeros(accuracy**2 * width * bbddiam).reshape((bbddiam * accuracy , width * accuracy))
# The "slow" function
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
"""Masks a half of the array"""
to_return = np.zeros(drum.shape)
for index, v in np.ndenumerate(to_return):
if upper == True:
if index[0] * coef + intercept > index[1]:
to_return[index] = 1
else:
if index[0] * coef + intercept <= index[1]:
to_return[index] = 1
return to_return

def get_band(drum, coef, intercept, bandwidth):
"""Calculate a ribbon path on the drum"""
to_return = np.zeros(drum.shape)
t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
to_return = t1 + t2
return np.where(to_return == 2, 1, 0)
single_band = get_band(drum, 1 / 10, 130, bandwidth=15)
# Visualize the result !
plt.imshow(single_band)
plt.show()

Numba 为我的代码创造了奇迹,将运行时间从 5.8 秒减少到 86 毫秒(特别感谢 @Maarten-vd-Sande(:

from numba import jit
@jit(nopython=True, parallel=True)
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
...

仍然欢迎使用numpy的更好解决方案;-(

这里根本不需要任何循环。您实际上有两种不同的line_mask功能。两者都不需要显式循环,但你可能会通过用ifelse中的一对for循环重写它来获得显着的加速,而不是在for循环中ifelse,后者被多次评估。

真正令人费解的事情是正确矢量化您的代码,以便在没有任何循环的情况下对整个数组进行操作。这是line_mask的矢量化版本:

def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
"""Masks a half of the array"""
r = np.arange(drum.shape[0]).reshape(-1, 1)
c = np.arange(drum.shape[1]).reshape(1, -1)
comp = c.__lt__ if upper else c.__ge__
return comp(r * coef + intercept)

rc的形状设置为(m, 1)(n, 1),以便(m, n)结果称为广播,并且是 numpy 中矢量化的主要内容。

更新line_mask的结果是一个布尔掩码(顾名思义(而不是浮点数组。这使得它更小,并希望完全绕过浮点操作。您现在可以重写get_band以使用掩码而不是加法:

def get_band(drum, coef, intercept, bandwidth):
"""Calculate a ribbon path on the drum"""
t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
return t1 & t2

程序的其余部分应保持不变,因为这些函数保留所有接口。

如果你愿意,你可以用三行(仍然有点清晰(重写大部分程序:

coeff = 1/10
intercept = 130
bandwidth = 15
r, c = np.ogrid[:drum.shape[0], :drum.shape[1]]
check = r * coeff + intercept
single_band = ((check + bandwidth / 2 > c) & (check - bandwidth /  2 <= c))

最新更新