使用 Numpy 和 Numba 将值数组装箱到离散集中最接近的值



下面有一个函数,它接受浮点数数组和离散整数数组。对于所有的浮点数,我希望它们四舍五入到列表中最接近的整数。

下面的函数完美地工作,其中sHatV是一个包含10,000个浮点数的数组,而possible_locations是一个包含5个整数的数组:

binnedV = [min(possible_locations, key=lambda x:abs(x-bv)) for bv in sHatV]

由于这个函数将被调用数千次,我试图使用@numba.njit装饰器来最小化计算时间。

我想在我的"麻木"函数中使用np.digitize,但它将值舍入到零。我希望所有的东西都被归为可能位置的一个值。

总的来说,我需要编写一个numba兼容的函数,它接受长度为N的第一个数组中的每个值,在数组2中找到与它最接近的值,并返回最接近的值,最终得到一个长度为N的数组,其中包含存储值。

任何帮助都是感激的!

这是一个运行速度快得多的版本,可能更"麻木";因为它使用numpy函数而不是隐式的列表推导式for循环:

import numpy as np
sHatV = [0.33, 4.18, 2.69]
possible_locations = np.array([0, 1, 2, 3, 4, 5])
diff_matrix = np.subtract.outer(sHatV, possible_locations)
idx = np.abs(diff_matrix).argmin(axis=1)
result = possible_locations[idx]
print(result)
# output: [0 4 3]

这里的想法是计算sHatvpossible_locations之间的差矩阵。在这个特殊的例子中,这个矩阵是:

array([[ 0.33, -0.67, -1.67, -2.67, -3.67, -4.67],
[ 4.18,  3.18,  2.18,  1.18,  0.18, -0.82],
[ 2.69,  1.69,  0.69, -0.31, -1.31, -2.31]])

然后,使用np.abs( ... ).argmin(axis=1),我们找到绝对差最小的每一行的索引。如果我们用这些索引索引原来的possible_locations数组,我们就得到了答案。

比较运行时间:

使用列表推导式

def f(possible_locations, sHatV):
return [min(possible_locations, key=lambda x:abs(x-bv)) for bv in sHatV]

def test_f():
possible_locations = np.array([0, 1, 2, 3, 4, 5])
sHatV = np.random.uniform(0.1, 4.9, size=10_000)
f(possible_locations, sHatV)

%timeit test_f()
# 187 ms ± 7.96 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

使用差分矩阵

def g(possible_locations, sHatV):
return possible_locations[np.abs(np.subtract.outer(sHatV, bins)).argmin(axis=1)]

def test_g():
possible_locations = np.array([0, 1, 2, 3, 4, 5])
sHatV = np.random.uniform(0.1, 4.9, size=10_000)
g(possible_locations, sHatV)
%timeit test_g()
# 556 µs ± 24.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

您可以使用numpy的np.searchSorted()函数。np.digitize()本身是根据np.searchSorted()实现的。例如,

import numpy as np
offset = 1e-8
indices = np.searchsorted(possible_locations, sHatV - offset)
return possible_locations[np.clip(indices, 0, len(int) - 1)]

我建议坚持使用numpy。digitize函数接近您所需要的,但需要进行一些修改:

  • 实现舍入逻辑而不是floor/ceiling
  • 帐户端点问题。文件显示:If values in `x` are beyond the bounds of `bins`, 0 or ``len(bins)`` is returned as appropriate.

下面是一个例子:

import numpy as np
sHatV = np.array([-99, 1.4999, 1.5, 3.1, 3.9, 99.5, 1000])
bins = np.arange(0,101)
def custom_round(arr, bins):
bin_centers = (bins[:-1] + bins[1:])/2 
idx = np.digitize(arr, bin_centers)
result = bins[idx]
return result
assert np.all(custom_round(sHatV, bins) == np.array([0, 1, 2, 3, 4, 100, 100]))

现在是我最喜欢的部分:numpy的速度有多快?我不做缩放,我们只选择大数组:

sHatV = 10009*np.random.random(int(1e6))
bins = np.arange(10000)
%timeit custom_round(sHatV, bins)
# on a laptop: 100 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

最新更新