Numba没有提高性能



我正在测试一些使用numpy数组的函数的numba性能,并比较:

import numpy as np
from numba import jit, vectorize, float64
import time
from numba.core.errors import NumbaWarning
import warnings
warnings.simplefilter('ignore', category=NumbaWarning)
@jit(nopython=True, boundscheck=False) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a):     # Function is compiled to machine code when called the first time
trace = 0.0
for i in range(a.shape[0]):   # Numba likes loops
trace += np.tanh(a[i, i]) # Numba likes NumPy functions
return a + trace              # Numba likes NumPy broadcasting

class Main(object):
def __init__(self) -> None:
super().__init__()
self.mat     = np.arange(100000000, dtype=np.float64).reshape(10000, 10000)
def my_run(self):
st = time.time()
trace = 0.0
for i in range(self.mat.shape[0]):   
trace += np.tanh(self.mat[i, i]) 
res = self.mat + trace
print('Python Diration: ', time.time() - st)
return res                           

def jit_run(self):
st = time.time()
res = go_fast(self.mat)
print('Jit Diration: ', time.time() - st)
return res

obj = Main()
x1 = obj.my_run()
x2 = obj.jit_run()

输出为:

Python Diration:  0.2164750099182129
Jit Diration:  0.5367801189422607

如何获得此示例的增强版本?

Numba实现的执行时间较慢是由于编译时间,因为Numba在使用函数时编译函数(只有第一次,除非参数类型发生变化(。它这样做是因为在调用函数之前无法知道参数的类型。希望您可以为Numba指定参数类型,这样它就可以直接编译函数(当decorator函数执行时(。这是生成的代码:

@njit('float64[:,:](float64[:,:])')
def go_fast(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace

请注意,njitjit+nopython=True的快捷方式,默认情况下boundscheck已设置为False(请参阅文档(。

在我的机器上,这导致Numpy和Numba的执行时间相同。实际上,执行时间不受tanh函数计算的限制。它由表达式a + trace(对于Numba和Numpy(限定。预计执行时间相同,因为两者的实现方式相同:它们创建一个临时的新数组来执行添加。由于页面故障和RAM的使用,创建新的临时阵列是昂贵的(a从RAM中完全读取,临时阵列完全存储在RAM中(。如果您想要更快的计算,则需要就地执行操作(这样可以防止x86平台上出现页面错误和昂贵的缓存线写入分配(。

最新更新