我有一个多敏感阵列,形状为(2、2、3(这样:
array([[[ 0.64, 0.49, 2.56],
[ 7.84, 13.69, 21.16]],
[[ 33.64, 44.89, 57.76],
[ 77.44, 94.09, 112.36]]])
我想找到每行最小的索引。因此,对于此示例,有4个最小值为:0.49、7.84、33.64和77.44。
要获得这些最低限度的索引,我认为这会有效:
idx_arr = np.unravel_index(np.argmin(my_array,axis=2),my_array.shape)
这产生以下索引:
(array([[0, 0],
[0, 0]]), array([[0, 0],
[0, 0]]), array([[1, 0],
[0, 0]]))
但是,最小值未正确计算,正如人们所看到的:
my_array[idx_arr]
array([[0.49, 0.64],
[0.64, 0.64]])
我在那里错过了什么?
Argmin实际上正确地计算值。但是您误解了np.unravel_index
的期望。
来自文档:
将平面索引或一系列平坦索引转换为元组 坐标阵列。
要查看在此处提供所需的输出将接受的输入,我们需要关注要点:它将以非平板的方式将平面数组转换为特定位置的正确坐标数组。从本质上讲,它的预期是您所需点的坐标,就好像您的输入阵列被扁平一样。
import numpy as np
inp = np.array([[[ 0.64, 0.49, 2.56],
[ 7.84, 13.69, 21.16]],
[[ 33.64, 44.89, 57.76],
[ 77.44, 94.09, 112.36]]])
idx = inp.argmin(axis=-1)
#Output:
array([[1, 0],
[0, 0]], dtype=int64)
请注意,您不能直接发送此idx
,因为它不能代表inp
数组的扁平版本的正确坐标。
看起来更像以下内容:
flat_idx = np.arange(0, idx.size*inp.shape[-1], inp.shape[-1]) + idx.flatten()
#Output:
array([1, 3, 6, 9], dtype=int64)
我们可以看到unravel_index
愉快地接受。
temp = np.unravel_index(flat_idx, inp.shape)
#Output:
(array([0, 0, 1, 1], dtype=int64),
array([0, 1, 0, 1], dtype=int64),
array([1, 0, 0, 0], dtype=int64))
inp[temp]
输出:
array([ 0.49, 7.84, 33.64, 77.44])
另外,看看输出元组,我们可以注意到,也很难重新创建同样的东西。请注意,最后一个数组对应于idx
的扁平形式,而前两个数组基本上通过inp
的前两个轴启用索引。
为了准备,我们实际上可以以一种相当漂亮的方式使用unravel_index
功能,如下:
real_idx = (*np.unravel_index(np.arange(idx.size), idx.shape), idx.flatten())
inp[real_idx]
#Output:
array([ 0.49, 7.84, 33.64, 77.44])