How can I use my numpy arrays more efficiently?



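For context, I am working on the Day 6 puzzle of Advent of Code 2021 and wanted to try numpy arrays instead of Python lists, because from my current understanding they should be faster. I have run into a problem, though: my solution prints the correct answer, but it takes ages to finish as the num_of_days_to_cycle_through variable grows.

I would like help understanding why the run time scales so badly, and how to notice and prevent this kind of mistake in my code. (lantern_fish_array is an int64 numpy array.)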

import numpy as np

def iterate_through_one_day(lantern_fish_array):
    iterator = 0
    copy_of_lantern_fish_array = lantern_fish_array.copy()
    for fish in copy_of_lantern_fish_array:
        if fish == 0:
            # np.append builds a brand-new array on every call
            lantern_fish_array = np.append(lantern_fish_array, 8)
            lantern_fish_array[iterator] = 6
        else:
            lantern_fish_array[iterator] -= 1
        iterator += 1
    del copy_of_lantern_fish_array
    return lantern_fish_array

def solve_part_1(lantern_fish_array):
    num_of_days_to_cycle_through = 256
    while num_of_days_to_cycle_through != 0:
        lantern_fish_array = iterate_through_one_day(lantern_fish_array)
        num_of_days_to_cycle_through -= 1
    return lantern_fish_array.size

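Numpy arrays are fast because an operation is applied to the whole array at once (vectorization) instead of element by element in Python. When you loop over them, they can be slower than lists. So unless your operation can be expressed in a vectorized way, looping or iterating over an np array may give you no benefit over doing the same thing on a list.

See these references on broadcasting and vectorization: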
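
To make that concrete, here is a minimal sketch. The function names `decrement_loop` and `decrement_vectorized` are made up for illustration and do not come from the question. It also shows that `np.append` returns a new array rather than growing the existing one, so calling it once per fish inside a loop copies the whole array every time.

```python
import numpy as np

def decrement_loop(timers):
    """Element-by-element loop: every step goes through the Python interpreter."""
    out = timers.copy()
    for i in range(out.size):
        out[i] -= 1
    return out

def decrement_vectorized(timers):
    """One vectorized expression: the loop runs inside NumPy's compiled code."""
    return timers - 1

timers = np.array([3, 4, 3, 1, 2], dtype=np.int64)

# Both versions compute the same result; only the cost per element differs.
print(np.array_equal(decrement_loop(timers), decrement_vectorized(timers)))  # True

# np.append does not modify `timers`; it allocates and fills a new array.
grown = np.append(timers, 8)
print(np.shares_memory(timers, grown))  # False
print(timers.size, grown.size)          # 5 6
```

If you time the two functions on a large array, the vectorized one should come out well ahead, but the exact ratio depends on your machine and NumPy version.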

https://realpython.com/numpy-array-programming/

https://unidata.github.io/python-training/workshop/NumPy/numpy-broadcasting-and-vectorization/

There are also some suggestions for the cases where you cannot avoid a loop:

Fast Numpy Loops

You can do this:

import numpy as np

# numpy style
def iterate_through_one_day_np(lantern_fish_array):
    new_array = np.copy(lantern_fish_array)
    mask_0 = new_array == 0  # True where a fish is about to spawn

    new_array[mask_0] = 6
    new_array[~mask_0] -= 1  # or new_array[~mask_0] = new_array[~mask_0] - 1
    # one new fish (timer 8) per spawning fish; np.full keeps the integer dtype
    # even when there are no zeros (appending an empty list would give floats)
    new_array = np.append(new_array, np.full(np.sum(mask_0), 8, dtype=new_array.dtype))
    return new_array

# your code for reference
def iterate_through_one_day(lantern_fish_array):
    iterator = 0
    new_array = np.copy(lantern_fish_array)
    for fish in lantern_fish_array:
        if fish == 0:
            new_array = np.append(new_array, 8)
            new_array[iterator] = 6
        else:
            new_array[iterator] -= 1
        iterator += 1

    return new_array

lantern_fish_array = [0,1,2,3,4,5]
iterate_through_one_day_np(lantern_fish_array)
# array([6, 0, 1, 2, 3, 4, 8])
iterate_through_one_day(lantern_fish_array)
# array([6, 0, 1, 2, 3, 4, 8])

Speed test with more fish, i.e. the list repeated 50 times: [0,1,2,3,4,5]*50

%%timeit -r 3 -n 3
iterate_through_one_day_np([0,1,2,3,4,5]*50)
# 94.4 µs ± 4.89 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)

%%timeit -r 3 -n 3
iterate_through_one_day([0,1,2,3,4,5]*50)
# 878 µs ± 154 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)

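For arrays in which 0 occurs many times, using numpy.where is much faster than modifying the values twice through indexing. For arrays with few zeros, it is also slightly faster than LHans's method: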

>>> def one_day_use_where(lantern_fish_array):
...     mask = lantern_fish_array == 0
...     return np.concatenate([np.where(mask, 6, lantern_fish_array - 1), np.repeat(8, mask.sum())])
...
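
As a small worked example of what the `np.where` call does, here is the same expression applied to a five-fish array. The input values are made up for illustration.

```python
import numpy as np

timers = np.array([0, 1, 2, 0, 5])
mask = timers == 0                      # [ True False False  True False]

# Where mask is True take 6, otherwise take the decremented timer.
next_day = np.where(mask, 6, timers - 1)

# One newborn with timer 8 for each fish that was at 0.
newborns = np.repeat(8, mask.sum())

result = np.concatenate([next_day, newborns])
print(result)  # [6 0 1 6 4 8 8]
```

Note that `timers - 1` is evaluated for every element, including the zeros; `np.where` then simply discards those values in favor of 6.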

Some tests:

>>> def one_day_use_index(lantern_fish_array):
...     mask = lantern_fish_array == 0
...     new_array = lantern_fish_array.copy()
...     new_array[mask] = 6
...     new_array[~mask] -= 1
...     return np.concatenate([new_array, np.repeat(8, mask.sum())])
...
>>> from timeit import timeit
>>> a = np.random.randint(0, 10, 10000)   # 0 accounts for about 10%
>>> timeit(lambda: one_day_use_where(a), number=10000)
0.4181383000104688
>>> timeit(lambda: one_day_use_index(a), number=10000)
0.8232910000078846
>>> a = np.random.randint(0, 2, 10000)    # 0 accounts for about 50%
>>> timeit(lambda: one_day_use_where(a), number=10000)
0.544302800000878
>>> timeit(lambda: one_day_use_index(a), number=10000)
1.917074600001797
>>> a = np.random.randint(1, 3, 10000)    # Does not contain 0
>>> timeit(lambda: one_day_use_where(a), number=10000)
0.38596799998776987
>>> timeit(lambda: one_day_use_index(a), number=10000)
0.3989579000044614
>>> a = np.zeros(10000, dtype=int)
>>> # If the proportion of 0 is too high, the performance will be slightly worse
>>> timeit(lambda: one_day_use_where(a), number=10000)
0.6532589000125881
>>> timeit(lambda: one_day_use_index(a), number=10000)
0.5977481000008993
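
On the scaling part of the question: every fish at 0 adds a new fish, so the population grows roughly exponentially with the number of days. Any version that stores one array element per fish, vectorized or not, will eventually run out of time or memory. Below is a sketch of a different technique from the ones timed above: keep only a count of fish per timer value. The name `count_fish_after` is hypothetical, and the sketch assumes timers are integers from 0 to 8 as in the puzzle.

```python
import numpy as np

def count_fish_after(timers, days):
    """Return the number of fish after `days`, tracking counts per timer value.

    Assumes every timer is an integer in the range 0..8.
    """
    # counts[t] = how many fish currently have timer t
    counts = np.bincount(timers, minlength=9).astype(np.int64)
    for _ in range(days):
        spawning = counts[0]
        # Shift every timer down by one; the old counts[0] wraps around to
        # index 8, which is exactly the number of newborn fish.
        counts = np.roll(counts, -1)
        # The parents restart their cycle at 6.
        counts[6] += spawning
    return int(counts.sum())

example = [3, 4, 3, 1, 2]  # the sample input from the puzzle text
print(count_fish_after(example, 80))   # 5934
print(count_fish_after(example, 256))  # 26984457539
```

The work per day is constant here (nine counters), so 256 days takes 256 small steps regardless of how many fish there are.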
