Numpy:需要最有效的方法来处理1D ndarray中的选择元素,使用2D ndarray的映射来输出1D平均ndar



首先,这不是家庭作业问题;这是一个与我的工作相关的真实问题的抽象。我真的很感谢所有的意见!

我需要运行类似于下面的计算,按顺序运行数万次,它的计算时间会显著影响我模拟的总持续时间:

在这个抽象中:

  • 我有60000个小部件和每个小部件类的一系列价格,"小部件价格">
  • 我有一个2D映射price_mapping,其中30000行中的每一行对应于购买这些小部件的篮子,并且60000列对应于与CCD_ 2。falseBool值表示小部件不在篮子中,true的值表示它们在篮子中
  • 我想生成一个数组,其中计算了30000个篮子中的每一个(每排price_mapping)

显示了数据结构的图示

下面是我写的一些代码,测试了我能想到的3种不同的方法。第一个,包括np.mean和一个常规python列表理解,第二个包括np.averagenp.tile。以及逐元素矩阵乘法,第三个包括np.manp.tilenp.mean

import numpy as np
import time
number_of_widgets = 60000
number_of_orders = 30000
widget_prices = np.random.uniform(0, 1, number_of_widgets)
price_mapping = np.random.randint(2, size=(number_of_orders, number_of_widgets), dtype=bool)
# method 1, using np.mean and a python list comprehension
start = time.time()
mean_price_array_1 = np.array([np.mean(widget_prices[price_mapping[i, :]]) for i in range(number_of_orders)])
end = time.time()
print('method 1 took ' + str(end - start) + ' seconds')
# method 2, using np.average, np.tile, and element-wise matrix multiplication
start = time.time()
mean_price_array_2 = np.average(np.tile(widget_prices, (number_of_orders, 1)) * price_mapping, weights=price_mapping,
axis=1)
end = time.time()
print('method 2 took ' + str(end - start) + ' seconds')
# method 3, using np.ma (masked array), np.tile, and np.mean
start = time.time()
mean_price_array_3 = np.ma.array(np.tile(widget_prices, (number_of_orders, 1)), mask=~price_mapping).mean(axis=1)
end = time.time()
print('method 3 took ' + str(end - start) + ' seconds')

这些是我得到的结果:

method 1 took 10.472509145736694 seconds
method 2 took 28.92689061164856 seconds
method 3 took 18.18838620185852 second

第一个是计算时间最快的,但对于我的需求来说仍然太慢了。

有什么方法可以提高对清单的理解吗?

提前感谢!!

-S

对于price_mapping作为每次迭代从widget_prices中选择元素的布尔掩码,我们可以简单地将matrix-multiplicationnp.dot一起用于矢量化解决方案,并有望以更快的方式,如so-

price_mapping.dot(widget_prices)/price_mapping.sum(1)

计算每行非零的一种更快的方法是使用np.count_nonzero。因此,另一种方式是

price_mapping.dot(widget_prices)/np.count_nonzero(price_mapping, axis=1)

如果你想快速计算,而numpy没有帮助,那么我建议使用numba。

1) 创建一个函数,用于列表理解的循环整数。2) 将@jit-decorator放在方法的开头,该方法将在多核PC上以并行方式运行它。3) 从numba导入jit

最新更新