我有一个numpy数组,比如
original_array=np.arange(5)
我有另一个数组,它存储要更新的索引值
indices=[0,1,2,1,1]
indices
的元素中的重复表示该元素被多次更新。
我还有一个数组,它存储要添加的值。
updation_values=[0.2,0.2,0.4, 0.5, 0.4]
通常我会将阵列更新为
for update_value, index in zip(updation_values, indices):
original_array[index]+=update_value
除了循环之外,还有更好的方法吗?
original_array[indices]+=updation_values
似乎不起作用,因为它只更新每个唯一索引的最后一个实例。
您的问题有两个:
-
为什么它不能按预期工作?
-
如何让它更快?
为什么它不起作用
您正在创建整数ndarray-这意味着分数更新将四舍五入到最接近的整数,并且由于更新是<0.5,没有明显变化。你可以看到ndarray在创建时为int:工作
original_array=np.arange(5)#, dtype=np.float)
indices=[0,1,2,1,1]
updation_values=[0.2,0.2,0.4, 0.5, 0.4]
updation_values=[1,1,1,1,1]
for update_value, index in zip(updation_values, indices):
original_array[index] += update_value
print(original_array)
[1 4 3 3 4]
要确保创建带有浮动的ndarray,必须使用dtype
可选参数:
original_array=np.arange(5, dtype=np.float)
如何让它更快
通常,您希望使用NumPy矢量化。不幸的是,您的indices
变量中存在重复,因此:
original_array[indices]+=updation_values
不会起作用。
你可以看到自己:
print(original_array)
original_array[[0,1,2]] += 10
print(original_aray)
结果:
[1 4 3 3 4]
[11 14 13 3 4]
也是:
original_array[[0,1,2,0,0]] += 10