在nopython模式下使用Numba的递归函数错误



我想使用 nopython 模式在 Numba 中运行一个递归函数。到目前为止,我只收到错误。这是一个非常简单的代码,用户给出一个少于五个元素的元组,然后该函数创建另一个元组,并向元组添加一个新值(在本例中为数字 3)。重复此操作,直到最后一个元组的长度为 5。由于某种原因,这不起作用,不知道为什么。

@njit
def tup(a):
if len(a) == 5:
return a
else:
b = a + (3,)
b = tup(b)
return b

例如,如果a = (0,1),我希望最终结果是元组(0,1,3,3,3)

编辑:我正在使用Numba 0.41.0,我得到的错误是内核死亡,"内核似乎已经死亡。它将自动重新启动。

有几个原因导致你不应该这样做:

  • 这通常是一种在纯 Python 中可能比在 numba 修饰函数中更快的方法。
  • 迭代会更简单,可能更快,但请注意,连接元组通常是一个O(n)操作,即使在 numba 中也是如此。所以功能的整体性能会O(n**2)。这可以通过使用支持追加的数据结构或支持预分配大小的数据结构来改进O(1)。或者干脆不使用"循环"或"递归"方法。
  • 您是否尝试过如果省略njit装饰器并传入包含 6 个元素的元组会发生什么?(提示:它将达到递归限制,因为它永远不会满足递归的结束条件)。

在编写 0.43.1 时,Numba 仅支持简单递归,当参数的类型在递归之间没有变化时。在您的情况下,类型确实发生了变化,您传入了一个tuple(int64 x 2)但递归调用尝试传入一个不同类型的tuple(int64 x 3)。奇怪的是,它在我的计算机上遇到了一个StackOverflow- 这似乎是 numba 中的一个错误。

我的建议是使用它(没有numba,没有递归):

def tup(a):
if len(a) < 5:
a += (3, ) * (5 - len(a))
return a

这也返回预期的结果:

>>> tup((1,))
(1, 3, 3, 3, 3)
>>> tup((1, 2))
(1, 2, 3, 3, 3)

根据当前版本中的提案列表:

numba 中的递归支持目前仅限于自递归 函数的显式类型注释。此限制来自 无法确定递归调用的返回类型。

因此,请尝试:

from numba import jit
@jit()
def tup(a:tuple) -> tuple:
if len(a) == 5:
return a
return tup(a + (3,))
print(tup((0, 1)))

看看这是否更适合您。

最新更新