这个numba函数中的错误是关于什么的



我已经编写了这个python函数,我相信它将被移植到numba。不幸的是,它没有,我不确定我是否理解错误:Invalid use of getiter with parameters (none)

它需要知道发电机的类型吗?是因为它返回可变长度的元组吗?

from numba import njit
# @njit
def iterator(N, k):
r"""Numba implementation of an iterator over tuples of N integers,
such that sum(tuple) == k.
Args:
N (int): number of elements in the tuple
k (int): sum of the elements
Returns:
tuple(int): a tuple of N integers
"""
if N == 1:
yield (k,)
else:
for i in range(k+1):
for j in iterator(N-1, k-i):
yield (i,) + j

编辑

感谢Jerome的提示。这是我最终写的解决方案(我从左边开始(:

import numpy as np
from numba import njit
@njit
def next_lst(lst, i, reset=False):
r"""Computes the next list of indices given the current list
and the current index.
"""
if lst[i] == 0:
return next_lst(lst, i+1, reset=True)
else:
lst[i] -= 1
lst[i+1] += 1
if reset:
lst[0] = np.sum(lst[:i+1])
lst[1:i+1] = 0
i = 0
return lst, i
@njit
def generator(N, k):
r"""Goes through all the lists of indices recursively.
"""
lst = np.zeros(N, dtype=np.int64)
lst[0] = k
i = 0
yield lst
while lst[-1] < k:
lst, i = next_lst(lst, i)
yield lst

这给出了正确的结果,它被篡改了!

for lst in generator(4,2):
print(lst)
[2 0 0 0]
[1 1 0 0]
[0 2 0 0]
[1 0 1 0]
[0 1 1 0]
[0 0 2 0]
[1 0 0 1]
[0 1 0 1]
[0 0 1 1]
[0 0 0 2]

一个问题来自可变大小元组输出。事实上,元组就像Numba中具有不同类型的结构。它们与列表(而不是Python(非常不同(AFAIK,在Python中,元组大致只是不能变异的列表(。在Numba中,一个由1项和2项组成的元组是两种不同的类型。它们不能统一为更通用的类型。问题是函数的返回falue必须是唯一的类型。因此,Numba拒绝在nopyson模式下编译函数。在Numba中解决此问题的唯一方法是使用列表

话虽如此,即使有列表,也会报告错误。文件指出:

支持大多数递归调用模式。唯一的限制是递归被调用者必须有一个返回时不递归的控制流路径

我认为这里没有满足这个限制,因为没有返回语句。话虽如此,函数应该隐式返回一个生成器(其类型取决于…递归函数本身(。还要注意,生成器的支持是相当新的,递归生成器没有得到很好的支持似乎是合理的。我建议你在Numba github上打开一个问题,因为我不确定这是预期的行为。

请注意,在不使用递归的情况下实现此函数可能更高效。顺便说一句,如果这个函数是从Numba函数而不是CPython调用的,那么它肯定会更快。

相关内容

最新更新