在没有循环的情况下更新numpy数组中的多个元素



我有一个numpy数组,比如

original_array=np.arange(5)

我有另一个数组,它存储要更新的索引值

indices=[0,1,2,1,1]

indices的元素中的重复表示该元素被多次更新。

我还有一个数组,它存储要添加的值。

updation_values=[0.2,0.2,0.4, 0.5, 0.4]

通常我会将阵列更新为

for update_value, index in zip(updation_values, indices):
original_array[index]+=update_value

除了循环之外,还有更好的方法吗?

original_array[indices]+=updation_values

似乎不起作用,因为它只更新每个唯一索引的最后一个实例。

您的问题有两个:

  1. 为什么它不能按预期工作?

  2. 如何让它更快?

为什么它不起作用

您正在创建整数ndarray-这意味着分数更新将四舍五入到最接近的整数,并且由于更新是<0.5,没有明显变化。你可以看到ndarray在创建时为int:工作

original_array=np.arange(5)#, dtype=np.float)
indices=[0,1,2,1,1]
updation_values=[0.2,0.2,0.4, 0.5, 0.4]
updation_values=[1,1,1,1,1]
for update_value, index in zip(updation_values, indices):
original_array[index] += update_value
print(original_array)

[1 4 3 3 4]

要确保创建带有浮动的ndarray,必须使用dtype可选参数:

original_array=np.arange(5, dtype=np.float)

如何让它更快

通常,您希望使用NumPy矢量化。不幸的是,您的indices变量中存在重复,因此:

original_array[indices]+=updation_values

不会起作用。

你可以看到自己:

print(original_array)
original_array[[0,1,2]] += 10
print(original_aray)

结果:

[1 4 3 3 4]

[11 14 13 3 4]

也是:

original_array[[0,1,2,0,0]] += 10

最新更新