Numpy:np.where(),带有新的维度



Python/numpy初学者,所以这应该很容易解决。给定浮点map的numpy 2d数组,例如

map = [[0.19982308 0.19982308 0.19986019 ... 0.25456086 0.25463998 0.25463998]
[0.19982308 0.19982308 0.19986019 ... 0.25456086 0.25463998 0.25463998]
[0.19998285 0.19998285 0.20000038 ... 0.25459546 0.25466287 0.25466287]
...
[0.4762167  0.4762167  0.47602317 ... 0.45300224 0.4541465  0.4541465 ]
[0.4767613  0.4767613  0.47632453 ... 0.45406988 0.45538843 0.45538843]
[0.4767613  0.4767613  0.47632453 ... 0.45406988 0.45538843 0.45538843]]

我想执行此操作:

new_map = np.where(map > 0.4, [255,255,255], [0,0,0])

也就是说,我想创建一个新的二维数组,该数组具有相同的维度,但使用RGB值而不是浮点值。哪个RGB值被分配给new_map[x][y]-白色=[255255255]或黑色=[0,0,0]-取决于map[x][y]是否高于阈值(在上述情况下为0.4(。

我收到以下错误消息:operands could not be broadcast together with shapes (512,512) (3,) (3,)

我想我理解了为什么-np.where限制为map的维度,实际上我正试图通过用float替换长度为3的嵌套数组来增加这些维度。

使用where或任何其他numpy操作是否有解决此问题的方法?谢谢

首先将映射转换为np数组:

import numpy as np
#import matplotlib.pyplot as plt
#create map
map = np.random.rand(200,200)
#to show
#plt.matshow(map)

output_var = np.zeros([*map.shape,3])
output_var[map>0.4]=np.array([255,255,255])

存在广播问题,因为numpy.where假定数组具有兼容的形状。假设期望的输出形状是(y, x, 3),您可以执行:

map_reshape = np.expand_dims(map, -1)
white = np.array([255, 255, 255]).reshape(1, 1, -1)
black = np.array([0, 0, 0]).reshape(1, 1, -1)
new_map = np.where(map_reshape >  0.4, white, black)

如果映射的形状是(y, x),则new_map的形状将是(y, x, 3)

最新更新