numpy 仅获取 n 长度子集作为给定 3 点索引批次的 3 点数据批次的最后一个 dim 的方法



任务:给定"值"和"ind"以最笨拙的方式获得"结果"。

输入:

import numpy as np
values = np.reshape(np.array([x/100 for x in range(4*5*10)]), (4, 5, 10))
ind = np.reshape(np.array([np.random.randint(0,10) for x in range(4*5*5)]), (4, 5, 5))

示例所需输出:

result = np.array([[[0.08, 0.02, 0.03, 0.01, 0.  ],
[0.18, 0.15, 0.17, 0.19, 0.17],
[0.29, 0.27, 0.24, 0.27, 0.2 ],
[0.39, 0.37, 0.33, 0.37, 0.3 ],
[0.46, 0.47, 0.48, 0.43, 0.49]],
[[0.56, 0.58, 0.57, 0.55, 0.52],
[0.63, 0.61, 0.63, 0.6 , 0.62],
[0.77, 0.74, 0.73, 0.71, 0.7 ],
[0.88, 0.82, 0.87, 0.82, 0.83],
[0.96, 0.95, 0.93, 0.98, 0.94]],
[[1.08, 1.09, 1.04, 1.02, 1.05],
[1.18, 1.16, 1.15, 1.12, 1.17],
[1.28, 1.29, 1.27, 1.21, 1.27],
[1.38, 1.38, 1.31, 1.35, 1.32],
[1.41, 1.49, 1.42, 1.48, 1.46]],
[[1.59, 1.5 , 1.56, 1.53, 1.51],
[1.6 , 1.69, 1.69, 1.6 , 1.68],
[1.79, 1.73, 1.72, 1.74, 1.77],
[1.84, 1.84, 1.83, 1.88, 1.8 ],
[1.98, 1.99, 1.91, 1.95, 1.92]]])

编辑:我的错,忘了指定随机种子。
编辑:非 Numpyic 版本的代码是:

result_ = np.zeros_like(result)
for batch_idx in range(len(values)):
for word_idx in range(len(values[0])):
result_[batch_idx][word_idx] = values[batch_idx,word_idx, ind[batch_idx, word_idx]]

我认为你需要的是:

import numpy as np
np.random.seed(100)
values = np.reshape(np.array([x/100 for x in range(4*5*10)]), (4, 5, 10))
ind = np.reshape(np.array([np.random.randint(0,10) for x in range(4*5*5)]), (4, 5, 5))
ii = np.arange(values.shape[0])[:, np.newaxis, np.newaxis]
jj = np.arange(values.shape[1])[np.newaxis, :, np.newaxis]
result = values[ii, jj, ind]
print(result)

输出:

[[[0.08 0.08 0.03 0.07 0.07]
[0.1  0.14 0.12 0.15 0.12]
[0.22 0.22 0.21 0.2  0.28]
[0.34 0.3  0.39 0.36 0.32]
[0.44 0.41 0.45 0.43 0.44]]
[[0.54 0.53 0.57 0.51 0.51]
[0.67 0.67 0.6  0.62 0.69]
[0.79 0.73 0.72 0.75 0.78]
[0.81 0.8  0.87 0.86 0.82]
[0.9  0.98 0.92 0.95 0.91]]
[[1.08 1.01 1.05 1.04 1.02]
[1.18 1.13 1.15 1.1  1.19]
[1.23 1.26 1.23 1.24 1.27]
[1.36 1.33 1.39 1.3  1.34]
[1.44 1.45 1.47 1.46 1.46]]
[[1.52 1.54 1.52 1.57 1.51]
[1.66 1.66 1.6  1.67 1.62]
[1.73 1.75 1.74 1.72 1.74]
[1.83 1.87 1.89 1.8  1.8 ]
[1.95 1.99 1.96 1.96 1.95]]]

最新更新