为什么每 100k 次迭代打印一次会破坏 numba 性能?



为什么每 100k 次迭代一次(即只打印 40 行!)print的代码需要 50 秒才能运行:

import numpy as np
from numba import jit
@jit
def doit():
A = np.random.random(4*1000*1000)
n = 300
Q = np.zeros(len(A)-n)
for i in range(len(Q)):
Q[i] = np.sum(A[i:i+n] <= A[i+n])
if i % 100000 == 0:  # print the progress once every 100k iterations
print("%i %.2f %% already done. " % (i, i * 100.0 / len(A)))
doit()

而,如果没有print只需要 2.4 秒

import numpy as np
from numba import jit
@jit
def doit():
A = np.random.random(4*1000*1000)
n = 300
Q = np.zeros(len(A)-n)
for i in range(len(Q)):
Q[i] = np.sum(A[i:i+n] <= A[i+n])
doit()

这是一个普遍的事实,print真的可以消除numba的好处吗?

如果您尝试使用@njit@jit(nopython=True)编译它,您将看到它正在从异常中以对象模式编译。这个版本在我的机器上运行大约 1 秒,带有 print 语句:

import numpy as np
from numba import jit
@jit(nopython=True)
def doit():
A = np.random.random(4*1000*1000)
n = 300
Q = np.zeros(len(A)-n)
for i in range(len(Q)):
Q[i] = np.sum(A[i:i+n] <= A[i+n])
if i % 100000 == 0:  # print the progress once every 100k iterations
print(i , "(",  i * 100.0 / len(A), '% already done)')

一般来说,如果你看到 numba 函数的性能很差,那是因为你是在 python 对象模式下编译的,所以总是把nopython=True是一个很好的做法,除非你真的想在 python 对象模式下使用它,因为它会回退到编译器无法编译为机器代码的一些语法。Numba确实做了一些循环提升,但就性能而言,这很难推理。

看:

http://numba.pydata.org/numba-doc/latest/user/5minguide.html#what-is-nopython-mode

相关内容

最新更新