Numba jitted len() is slower than pure Python len()



我正在学习numba,遇到了这种我不理解的"奇怪"行为。 我尝试使用以下代码(在iPython中,用于计时(:

import numpy as np
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
def py_len(seq):
return len(seq)
##
t = np.random.rand(1000)
%timeit nb_len(t)
%timeit py_len(t)

结果如下(实际上是由于编译 numba 而导致的第二次运行(:

258 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
137 ns ± 0.964 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

纯python版本是numba版本的两倍。 我也尝试了签名@nb.njit( nb.int32(nb.float64[:]) )但结果仍然是一样的。

我在某处犯了错误吗?

谢谢。

不是len((部分增加了时间。 使用输入参数调用 jit 函数会增加开销,这就是您看到的时差。

import numba as nb
def py_pass(i):
return i
@nb.njit()
def nb_pass(i):
return i
%timeit py_pass(1)
%timeit nb_pass(1)

包含输入参数的结果

102 ns ± 0.371 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
165 ns ± 0.783 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

有趣的是,如果你不需要向 jit 函数传递任何东西,它会更快:

def py_pass():
return 1
@nb.njit()
def nb_pass():
return 1
%timeit py_pass()
%timeit nb_pass()

不带输入参数的结果

96.6 ns ± 0.278 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
75.8 ns ± 0.221 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

正如另一个答案所说,这不是因为在这种情况下len函数,而是因为对 numba 函数的调用实际上比对普通 Python 函数的调用慢。

是什么让jit-ted 函数与众不同?

要理解为什么调用 numba jitted 函数的速度较慢,必须了解 numba jitted 函数不再是函数。它是一个调度程序对象:

import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
print(nb_len)  # CPUDispatcher(<function nb_len at 0x0000027EB1B4E798>)

CPUDispatcher实例表示(可能(基于修饰函数生成的多个已编译函数。

这意味着当您调用CPUDispatcher实例时,有多个步骤:

  • 获取参数的类型。
  • 如果没有适合这些类型的参数的编译函数,请使用参数类型编译修饰函数。
  • 有时:将参数转换为相应的 numba 类型。
  • 调用已编译的函数。

与非修饰函数相比,所有这些步骤都会增加开销。特别是如果没有合适的编译函数并且调度程序需要编译函数 - 或者 - 输入类型需要转换(仅适用于 Python 类型,例如:列表、集合、字典(调用CPUDispatcher会慢得多 - 这些类型在撰写 numba 0.46 时已被弃用,部分原因是, 请参阅"2.11.2.弃用列表和集类型的反射"。

在您的情况下

在您的情况下,由于编译,对抖动函数的首次调用将明显变慢。

任何后续调用只会稍微慢一点,因为 numba 必须获取参数类型,检查是否已经有一个编译的函数,然后调用该编译的函数。有趣的是,额外的时间取决于参数的数量和该函数已经编译的"重载"的数量。通常,这个额外的时间是微不足道的,因为该函数的作用远不止调用len

编译时间

尽管该函数非常简单,但第一次调用时的编译需要花费大量时间:

import numpy as np
import numba as nb
def first_call(seq):
@nb.njit
def nb_len(seq):
return len(seq)
return nb_len(seq)
@nb.njit
def _nb_len(seq):
return len(seq)
def subsequent_calls(seq):
return _nb_len(seq)
t = np.random.rand(1000)
_nb_len(np.ones(1, dtype=np.float64))
%timeit first_call(t)
# 29.8 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit subsequent_calls(t)
# 384 ns ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

转换时间

此外,如果 numba 需要转换参数,它会慢得多。这只发生在 numba 无法直接处理的 Python 类型上,例如列表:

import numpy as np
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
arr = np.random.rand(10_000)
lst = arr.tolist()
nb_len(arr)
nb_len(lst)
%timeit nb_len(arr)
# 354 ns ± 24 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit nb_len(lst)
# 14.1 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

总结

  • 与普通的 Python 函数相比,Numba 函数有一些额外的开销。因此,请确保你做了numba擅长优化的"足够"的事情,否则一个简单的Python函数将更快,更灵活,更容易调试。
  • numba 函数中的函数调用实际上可以不同于 numba 函数之外的函数调用。因此,nb_len中的len()py_len中的len()可以具有完全不同的运行时间。但是,在这种情况下,运行时几乎相同。但通常最好意识到这一点。
  • 根据参数类型,numba 函数可能(在幕后(非常慢,特别是如果将 Python 类型作为参数或返回类型处理!

最新更新