我有一个数据数组a:
>>>a
array([[1., 2., 3.],
[4., 5., 6.]], dtype=float32)
现在我想对数据做这样的处理:
>>>for n in range (1,4):
>>> a += 2
>>> print(a)
[[3. 4. 5.]
[6. 7. 8.]]
[[ 5. 6. 7.]
[ 8. 9. 10.]]
[[ 7. 8. 9.]
[10. 11. 12.]]
,它将给出a的最终结果:
>>> a
array([[ 7., 8., 9.],
[10., 11., 12.]], dtype=float32)
如果我想在每个元素的值为>8,这可能会给我这样的最终结果:
>>> a
array([[7., 8., 7.],
[8., 7., 8.]], dtype=float32)
怎么做?
谢谢!
一种方法是:
import numpy as np
a = np.array([[1., 2., 3.],
[4., 5., 6.]], dtype=np.float32)
for n in range(1, 4):
a += 2 * ((a + 2) <= 8)
print(a)
[[7. 8. 7.]
[8. 7. 8.]]
这个想法是将2乘以一个布尔掩码,如果2可以加到a
的元素上,则该掩码为1,否则为0。
作为一种选择,您可以通过执行以下操作完全跳过for循环:
a = 8 - (a % 2)
print(a)
[[7. 8. 7.]
[8. 7. 8.]]
上面的解决方案是基于这样一个事实,即a中的偶数最终为8,奇数最终为7,这当然是假设a
中的所有数字都小于8.
使用np.where
a = np.array([[1., 2., 3.],
[4., 5., 6.]], dtype=np.float32)
for i in range(1, 4):
a = np.where(a + 2 <= 8, a+2, a)
(需要Python 3.8+)
可以使用Walrus操作符来避免两次计算+ 2:
for i in range(1, 4):
a = np.where((p:=a + 2) <= 8, p, a)
print(a)
array([[7., 8., 7.],
[8., 7., 8.]], dtype=float32)
Syntax :numpy.where(condition[, x, y])
Parameters:
condition : When True, yield x, otherwise yield y.