传递一个形状给numpy.在numba njit环境中重塑失败,我如何为目标形状创建合适的可迭代对象?



我有一个函数,它接受一个数组,执行一个任意计算并返回一个可以广播的新形状。我想在numba.njit环境中使用此函数:

import numpy as np
import numba as nb
@nb.njit
def generate_target_shape(my_array):
### some functionality that calculates the desired target shape ###
return tuple([2,2])

@nb.njit
def test():
my_array = np.array([1,2,3,4])
target_shape = generate_target_shape(my_array)
reshaped = my_array.reshape(target_shape)
print(reshaped)
test()

然而,在numba中不支持元组创建,当尝试使用tuple()操作符将generate_target_shape的结果强制转换为元组时,我得到以下错误消息:

No implementation of function Function(<class 'tuple'>) found for signature:

>>> tuple(list(int64)<iv=None>)

There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload of function 'tuple': File: numba/core/typing/builtins.py: Line 572.
With argument(s): '(list(int64)<iv=None>)':
No match.
During: resolving callee type: Function(<class 'tuple'>

如果我试图将generate_target_shape的返回类型从tuple更改为listnp.array,我收到以下错误消息:

Invalid use of BoundFunction(array.reshape for array(float64, 1d, C)) with parameters (array(int64, 1d, C))

是否有一种方法可以让我在nb.njit函数中创建一个可迭代的对象,可以传递给np.reshape?

编辑:我通过使用objmode构造函数来解决这个问题。

numba似乎不支持标准python函数tuple()。您可以通过稍微重写代码来轻松解决此问题:

import numpy as np
import numba as nb
@nb.njit
def generate_target_shape(my_array):
### some functionality that calculates the desired target shape ###
a, b = [2, 2] # (this will also work if the list is a numpy array)
return a, b
然而,一般情况要棘手得多。我将回溯我在评论中所说的话:不可能或不建议使用numba编译函数来处理许多不同大小的元组。这样做需要为每个具有唯一大小的元组重新编译函数。@Jérôme Richard在stackoverflow的答案中很好地解释了这个问题。

我建议您做的是,简单地使用包含形状和数据的数组,并在numba编译函数之外计算my_array.reshape(tuple(target_shape))。它不是很漂亮,但它可以让你继续你的项目。

最新更新