使用numba JIT加速函数的问题



我是numba的jit新手。对于个人项目,我需要加快与下面显示的功能类似的功能,尽管出于编写独立示例的目的有所不同。

import numpy as np
from numba import jit, autojit, double, float64, float32, void
def f(n):
    k=0.
    for i in range(n):
        for j in range(n):
            k+= i+j
def f_with_return(n):
    k=0.
    for i in range(n):
        for j in range(n):
            k+= i+j
    return k
def f_with_arange(n):
    k=0.
    for i in np.arange(n):
        for j in np.arange(n):
            k+= i+j
def f_with_arange_and_return(n):
    k=0.
    for i in np.arange(n):
        for j in np.arange(n):
            k+= i+j  

#jit decorators
jit_f = jit(void(int32))(f)
jit_f_with_return = jit(int32(int32))(f_with_return)
jit_f_with_arange = jit(void(double))(f_with_arange)
jit_f_with_arange_and_return = jit(double(double))(f_with_arange_and_return)

基准:

%timeit f(1000)
%timeit jit_f(1000)

10个环路,最佳3:每个环路73.9 ms/1000000个环路,最好3:212 ns每个环路

%timeit f_with_return(1000)
%timeit jit_f_with_return(1000)

10个环路,3个最佳值:每个环路74.9 ms/100万个环路,每个环路最佳值:220 ns

我不理解这两个:

%timeit f_with_arange(1000.0)
%timeit jit_f_with_arange(1000.0)

10个环路,3个最佳值:每个环路175 ms/1个环路,每个环路 3个最佳:167 ms

%timeit f_with_arange_with_return(1000.0)
%timeit jit_f_with_arange_with_return(1000.0)

10个环路,3个最佳值:每个环路174 ms/1个环路,每个环路 3个最佳:172 ms

我认为我没有为jit函数提供正确的输出和输入类型?仅仅因为for循环现在在numpy.arange上运行,而不再是一个简单的范围,所以我无法使用jit来使它更快。这里的问题是什么?

简单地说,numba不知道如何将np.arange转换为低级别的本地循环,因此它默认为对象层,该对象层的速度要慢得多,通常与纯python的速度相同。

一个不错的技巧是将nopython=True关键字参数传递给jit,看看它是否可以在不使用对象模式的情况下编译所有内容:

import numpy as np
import numba as nb
def f_with_return(n):
    k=0.
    for i in range(n):
        for j in range(n):
            k+= i+j
    return k
jit_f_with_return = nb.jit()(f_with_return)
jit_f_with_return_nopython = nb.jit(nopython=True)(f_with_return)
%timeit f_with_return(1000)
%timeit jit_f_with_return(1000)
%timeit jit_f_with_return_nopython(1000)

最后两个在我的机器上速度相同,比未编译的代码快得多。您有问题的两个示例将引发nopython=True的错误,因为它此时无法编译np.arange

有关更多详细信息,请参阅以下内容:

http://numba.pydata.org/numba-doc/0.17.0/user/troubleshoot.html#the-编译的代码太慢

以及对于支持的numpy功能的列表,其中指示在nopython模式中支持什么和不支持什么:

http://numba.pydata.org/numba-doc/0.17.0/reference/numpysupported.html

最新更新