numpy np.all轴参数的解决方法;与numba的兼容性



我有一个函数,给定一个xy坐标的numpy数组,它会过滤那些位于L侧框内的坐标

import numpy as np
from numba import njit
np.random.seed(65238758)
L = 10
N = 1000
xy = np.random.uniform(0, 50, (N, 2))
box = np.array([
[0,0],  # lower-left
[L,L]  # upper-right
]) 
def sinjit(xy, box):
mask = np.all(np.logical_and(xy >= box[0], xy <= box[1]), axis=1)
return xy[mask]

如果我运行这个函数,它会返回正确的结果:

sinjit(xy, box)
Output: array([[5.53200522, 7.86890708],
[4.60188554, 9.15249881],
[9.072563  , 5.6874726 ],
[4.48976127, 8.73258166],
...
[6.29683131, 5.34225758],
[2.68057087, 5.09835442],
[5.98608603, 4.87845464],
[2.42049857, 6.34739079],
[4.28586677, 5.79125413]])

但是,由于我想通过使用numba在循环中加速此任务,np.all函数中的"axis"参数存在兼容性问题(它不是在nopyson模式下实现的(。所以,我的问题是,有没有可能以任何方式避免这种争论?有什么变通办法吗?

我真的,真的,非常希望numba支持可选的关键字参数。在它出现之前,我几乎忽略了它。然而,这里可能会有一些技巧。

你需要格外小心任何不是二维的或长度可能为零的东西。

import numpy as np
from numba import njit
@njit(cache=True)
def np_all_axis0(x):
"""Numba compatible version of np.all(x, axis=0)."""
out = np.ones(x.shape[1], dtype=np.bool8)
for i in range(x.shape[0]):
out = np.logical_and(out, x[i, :])
return out
@njit(cache=True)
def np_all_axis1(x):
"""Numba compatible version of np.all(x, axis=1)."""
out = np.ones(x.shape[0], dtype=np.bool8)
for i in range(x.shape[1]):
out = np.logical_and(out, x[:, i])
return out
@njit(cache=True)
def np_any_axis0(x):
"""Numba compatible version of np.any(x, axis=0)."""
out = np.zeros(x.shape[1], dtype=np.bool8)
for i in range(x.shape[0]):
out = np.logical_or(out, x[i, :])
return out
@njit(cache=True)
def np_any_axis1(x):
"""Numba compatible version of np.any(x, axis=1)."""
out = np.zeros(x.shape[0], dtype=np.bool8)
for i in range(x.shape[1]):
out = np.logical_or(out, x[:, i])
return out
if __name__ == '__main__':
x = np.array([[1, 1, 0, 0], [1, 0, 1, 0]], dtype=bool)
np.testing.assert_array_equal(np.all(x, axis=0), np_all_axis0(x))
np.testing.assert_array_equal(np.all(x, axis=1), np_all_axis1(x))
np.testing.assert_array_equal(np.any(x, axis=0), np_any_axis0(x))
np.testing.assert_array_equal(np.any(x, axis=1), np_any_axis1(x))

我不确定这会有多高的性能,但如果你真的需要在更高级别的jit’ed函数中调用该函数,那么这将允许你这样做。

numpy不支持任何可选标志。all((由numba支持:https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

如果你坚持使用numba,唯一的方法就是用另一种方式进行编码。

我遇到了同样的问题,决定做一些实验。我发现,如果轴axis可以是一个整数文字(即它提前已知,不需要从变量中检索(,则是一个兼容Numba的替代方案。话虽如此,我还发现这些解决方案在我的测试函数中使用JIT编译时速度较慢,所以如果你想使用它,一定要对你的函数进行基准测试,以确保有实际的净改进。

正如其他人所指出的,Numba不支持包括np.all在内的几个NumPy函数的axis参数。我想到的第一个可能的解决方案是np.amin(又名np.ndarray.min(:对于布尔数组,np.all(a, axis=axis)np.amin(a, axis=axis)相同,对于数值数组,则与np.amin(a, axis=axis).astype('bool')相同。不幸的是,np.amin在不支持axis参数的函数列表中也是。但是,np.argmin确实支持axis参数,np.take_along_axis也是。

因此,np.all(a, axis=axis)可以替换为

对于数字数组:

np.take_along_axis(a, np.expand_dims(np.argmin(a, axis=axis), axis), axis)[(:, ){axis}0].astype('bool')

对于布尔数组:

np.take_along_axis(a, np.expand_dims(np.argmin(a.astype('int64'), axis=axis), axis), axis)[(:, ){axis}0]

  • 我不知道为什么需要.astype('int64'),所以我提交了一份错误报告

分离的部分(:, ){axis}应替换为axis重复:,,以便消除正确的轴。例如,如果a是一个布尔数组,而axis2,则可以使用

CCD_ 28。

基准

关于这一点,我所能说的是,如果真的在一个函数中需要一个numpy.all替代方案,而这个函数总体上会从JIT编译中受益匪浅,那么这个解决方案是合适的。如果你真的只是想加快all本身的速度,你不会有太大的运气。

测试.py

import numba
import numpy as np

# @numba.njit  # raises a TypingError
def using_all():
n = np.arange(10000).reshape((-1, 5))  # numeric array
b = n < 4888  # boolean array
return (np.all(n, axis=1),
np.all(b, axis=1))

# @numba.njit  # raises a TypingError
def using_amin():
n = np.arange(10000).reshape((-1, 5))  # numeric array
b = n < 4888  # boolean array
return (np.amin(n, axis=1).astype('bool'),
np.amin(b, axis=1))

@numba.njit  # doesn't raise a TypingError
def using_take_along_axis():
n = np.arange(10000).reshape((-1, 5))  # numeric array
b = n < 4888  # boolean array
return (np.take_along_axis(n, np.expand_dims(np.argmin(n, axis=1), 1), 1)[:, 0].astype('bool'),
np.take_along_axis(b, np.expand_dims(np.argmin(b.astype('int64'), axis=1), 1), 1)[:, 0])

if __name__ == '__main__':
a = using_all()
m = using_amin()
assert np.all(a[0] == m[0])
assert np.all(a[1] == m[1])
t = using_take_along_axis()
assert np.all(a[0] == t[0])
assert np.all(a[1] == t[1])
PS C:> python -m timeit -n 10000 -s 'from test import using_all; using_all()' 'using_all()'           
10000 loops, best of 5: 32.9 usec per loop
PS C:> python -m timeit -n 10000 -s 'from test import using_amin; using_amin()' 'using_amin()'
10000 loops, best of 5: 43.5 usec per loop
PS C:> python -m timeit -n 10000 -s 'from test import using_take_along_axis; using_take_along_axis()' 'using_take_along_axis()'
10000 loops, best of 5: 55.4 usec per loop

最新更新