首先,这不是家庭作业问题;这是一个与我的工作相关的真实问题的抽象。我真的很感谢所有的意见!
我需要运行类似于下面的计算,按顺序运行数万次,它的计算时间会显著影响我模拟的总持续时间:
在这个抽象中:
- 我有60000个小部件和每个小部件类的一系列价格,"小部件价格">
- 我有一个2D映射
price_mapping
,其中30000行中的每一行对应于购买这些小部件的篮子,并且60000列对应于与CCD_ 2。false
的Bool
值表示小部件不在篮子中,true
的值表示它们在篮子中 - 我想生成一个数组,其中计算了30000个篮子中的每一个(每排
price_mapping
)
显示了数据结构的图示
下面是我写的一些代码,测试了我能想到的3种不同的方法。第一个,包括np.mean
和一个常规python列表理解,第二个包括np.average
、np.tile
。以及逐元素矩阵乘法,第三个包括np.ma
、np.tile
和np.mean
。
import numpy as np
import time
number_of_widgets = 60000
number_of_orders = 30000
widget_prices = np.random.uniform(0, 1, number_of_widgets)
price_mapping = np.random.randint(2, size=(number_of_orders, number_of_widgets), dtype=bool)
# method 1, using np.mean and a python list comprehension
start = time.time()
mean_price_array_1 = np.array([np.mean(widget_prices[price_mapping[i, :]]) for i in range(number_of_orders)])
end = time.time()
print('method 1 took ' + str(end - start) + ' seconds')
# method 2, using np.average, np.tile, and element-wise matrix multiplication
start = time.time()
mean_price_array_2 = np.average(np.tile(widget_prices, (number_of_orders, 1)) * price_mapping, weights=price_mapping,
axis=1)
end = time.time()
print('method 2 took ' + str(end - start) + ' seconds')
# method 3, using np.ma (masked array), np.tile, and np.mean
start = time.time()
mean_price_array_3 = np.ma.array(np.tile(widget_prices, (number_of_orders, 1)), mask=~price_mapping).mean(axis=1)
end = time.time()
print('method 3 took ' + str(end - start) + ' seconds')
这些是我得到的结果:
method 1 took 10.472509145736694 seconds
method 2 took 28.92689061164856 seconds
method 3 took 18.18838620185852 second
第一个是计算时间最快的,但对于我的需求来说仍然太慢了。
有什么方法可以提高对清单的理解吗?
提前感谢!!
-S
对于price_mapping
作为每次迭代从widget_prices
中选择元素的布尔掩码,我们可以简单地将matrix-multiplication
与np.dot
一起用于矢量化解决方案,并有望以更快的方式,如so-
price_mapping.dot(widget_prices)/price_mapping.sum(1)
计算每行非零的一种更快的方法是使用np.count_nonzero
。因此,另一种方式是
price_mapping.dot(widget_prices)/np.count_nonzero(price_mapping, axis=1)
如果你想快速计算,而numpy没有帮助,那么我建议使用numba。
1) 创建一个函数,用于列表理解的循环整数。2) 将@jit-decorator放在方法的开头,该方法将在多核PC上以并行方式运行它。3) 从numba导入jit