numba.errors.TypingError: 在 nopython 模式下失败 管道(步骤:nopython 前端) 无法确定变量'argmax'的类型



我试图通过使用numba来加速我的python代码。但经过几天的尝试和数百条错误消息,我仍然没能成功。

我目前的问题是这个错误消息:

Traceback (most recent call last):
File "E:/Studium/Masterarbeit/Masterarbeit/code/fast_simulation.py", line 152, in <module>
epsis, ks, vn, ln = monte_carlo(n, m, alpha, epsilon_max, delta, aufloesung, fehlerposition)
File "C:Pythonlibsite-packagesnumbadispatcher.py", line 401, in _compile_for_args
error_rewrite(e, 'typing')
File "C:Pythonlibsite-packagesnumbadispatcher.py", line 344, in error_rewrite
reraise(type(e), e, None)
File "C:Pythonlibsite-packagesnumbasix.py", line 668, in reraise
raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Type of variable 'argmax' cannot be determined, operation: $300unpack_sequence.5, location: E:/Studium/Masterarbeit/Masterarbeit/code/fast_simulation.py (126)
File "fast_simulation.py", line 126:
def monte_carlo(n: int, m: int, alpha: float, epsilon_max: float, delta: float, aufloesung: int, x_position: float):
<source elided>
argmax, max_value = get_max_ks(random_values)

我的代码是

@njit
def insort(a, x, lo=0, hi=None):
"""Insert item x in list a, and keep it sorted assuming a is sorted.
If x is already in a, insert it to the right of the rightmost x.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if x < a[mid]:
hi = mid
else:
lo = mid+1
a.insert(lo, x)
@njit
def get_max_ks(data: list) -> (float, float):
def f(x, data):
return sqrt(len(data)) * abs(np.searchsorted(data, x, side='right') / n - x)
max_value = -1000
argmax = 0.0
epsilon = 1 / (n * 10 ** 6)
for i in range(n):
data_i = data[i]
fx1 = f(data_i, data)
fx2 = f(data_i + epsilon, data)
fx3 = f(data_i - epsilon, data)
if fx1 > max_value:
max_value = fx1
argmax = data_i
if fx2 > max_value:
max_value = fx2
argmax = data_i + epsilon
if fx3 > max_value:
max_value = fx3
argmax = data_i - epsilon
return argmax, max_value
@njit
def get_max_vn(data: list) -> (float, float):
def g(x, data):
return sqrt(len(data)) * abs(np.searchsorted(data, x, side='right') / n - x) / sqrt(
x * (1 - x))
max_value = -1000
argmax = 0.0
epsilon = 1 / (n * 10 ** 6)
for i in range(n):
data_i = data[i]
fx1 = g(data_i, data)
fx2 = g(data_i + epsilon, data)
fx3 = g(data_i - epsilon, data)
if fx1 > max_value:
max_value = fx1
argmax = data_i
if fx2 > max_value:
max_value = fx2
argmax = data_i + epsilon
if fx3 > max_value:
max_value = fx3
argmax = data_i - epsilon
return argmax, max_value

@njit
def get_critical_value_vn(alpha: float, n: int) -> float:
loglogn = log(log(n))
an = sqrt(2 * loglogn)
dn = 2 * loglogn + 0.5 * log(loglogn) - 0.5 * log(pi)
return (dn - log(-0.5 * log(1 - alpha))) / an

@njit
def monte_carlo(n: int, m: int, alpha: float, epsilon_max: float, delta: float, aufloesung: int, x_position: float):
epsilons = np.linspace(min(0.0, epsilon_max), max(0.0, epsilon_max), aufloesung)
res_ks = np.zeros(n)
res_vn = np.zeros(n)
res_ln = np.zeros(n)
ks_critical_value = 1.2238478702170836  # TODO
vn_critical_value = get_critical_value_vn(alpha, n)
ln_critical_value = 2.7490859400901955  # TODO
for epsilon in epsilons:
for i in range(m):
uniform_distributed_values = np.random.uniform(0.0, 1.0, n)
random_values = []
for x in uniform_distributed_values:
if x < max(0, x_position - delta) or x > min(1, x_position + delta):
insort(random_values, x)
elif max(0, x_position - delta) <= x and x <= x_position + epsilon:
insort(random_values,
(x - x_position - epsilon) / (1 + (epsilon / min(delta, x_position))) + x_position)
else:
insort(random_values,
(x - x_position - epsilon) / (
1 - (epsilon / min(delta, 1 - x_position))) + x_position)
argmax, max_value = get_max_ks(random_values)
vn_argmax, vn_max = get_max_vn(random_values)
ks_statistic = max_value
vn_statistic = vn_max
ln_statistic = max_value / sqrt(argmax * (1 - argmax))
if ks_statistic > ks_critical_value:  # if test dismisses H_0
res_ks[i] += 1 / m
if vn_statistic > vn_critical_value:  # if test dismisses H_0
res_vn[i] += 1 / m
if ln_statistic > ln_critical_value:  # if test dismisses H_0
res_ln[i] += 1 / m
return epsilons, res_ks, res_vn, res_ln

if __name__ == '__main__':
# some code
epsis, ks, vn, ln = monte_carlo(n, m, alpha, epsilon_max, delta, aufloesung, fehlerposition)
# some other code

知道我该怎么解决这个问题吗?

我在numba github页面上为这个问题打开了一个问题。问题是numba无法确定列表random_values的类型。解决方案是使用如下键入的列表:

random_values = typed.List.empty_list(types.float64)

最新更新