为什么Numba会扭曲jit编译函数的计时?



我试图对一个Python函数进行基准测试,该函数使用Numba对CPython解释器进行列表操作。为了比较端到端时间,我使用了Linux时间实用程序。time python3.10 list.py

据我所知,由于JIT编译,第一次调用将是昂贵的,但它并没有解释为什么最大记录时间比运行整个脚本所花费的总时间长。

# list.py
import numpy as np
from time import time, perf_counter 
from numba import njit
@njit
def listOperations():
list = []
for i in range(1000):
list.append(i)

list.sort(reverse=True)
list.remove(420)
list.reverse()
if __name__ == "__main__":
repetitions = 1000
timings = np.zeros(repetitions)
for rep in range(repetitions):
start = time()  # Similar results with perf_counter too.
listOperations()
timings[rep] = time() - start
# Convert to milliseconds
timings *= 10e3
print("Mean {}ms, Median {}ms, Std. Dev {}ms, Min {}ms, Max {}ms".format(
float('%.4f' % np.mean(timings)), 
float('%.4f' % np.median(timings)), 
float('%.4f' % np.std(timings)), 
float('%.4f' % np.min(timings)), 
float('%.4f' % np.max(timings)))
)

对于Numba,它显示最大66.3秒,而时间实用程序报告~8s。完整的结果如下:

'''
Numba --->
Mean 66.8154ms, Median 0.391ms, Std. Dev 2097.7752ms, Min 0.3219ms, Max 66371.1143ms
real  0m7.982s
user  0m8.248s
sys   0m0.100s
CPython3.10 --->
Mean 1.6395ms, Median 1.6284ms, Std. Dev 0.0708ms, Min 1.5759ms, Max 2.3198ms
real. 0m1.115s
user  0m1.468s
sys   0m0.080s 
'''

主要问题是编译时间包含在计时中。实际上,Numba是惰性地编译这些函数的。为了防止这种情况,你必须指定原型或者在外部执行第一个函数调用(这在基准测试中通常是一个很好的做法)。

使用@njit('()')而不是@njit。通过这个修复,Numba代码在我的机器上大约快了两倍。

请注意,您的函数不返回任何内容,也不读取任何参数,因此JIT可以将函数优化为无操作。为了避免偏差,您当然需要添加一个参数,使用它并返回列表。这显然不是我的机器上的情况,但不同版本的Numba可能会这样做。

还要注意Numba列表通常不是Numba发光的地方。列表通常很慢(有和没有Numba)。最好在大小已知的情况下使用array.

顺便说一下,list是一个内置函数。覆盖它可能会在使用它的模块中(经常)导致隐蔽的bug,所以这不是一个好主意。我建议你用另一个名字。

此外,请注意,结果中的标准差相当大,中位数时间很好,最大时间非常大,这表明计时不稳定,这种不稳定是由于一个缓慢的调用。这样的结果通常表明基准测试有缺陷,或者函数本身有不稳定的行为(通常是由于错误或一次初始化)。

最新更新