我正在学习numba,遇到了这种我不理解的"奇怪"行为。 我尝试使用以下代码(在iPython中,用于计时(:
import numpy as np
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
def py_len(seq):
return len(seq)
##
t = np.random.rand(1000)
%timeit nb_len(t)
%timeit py_len(t)
结果如下(实际上是由于编译 numba 而导致的第二次运行(:
258 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
137 ns ± 0.964 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
纯python版本是numba版本的两倍。 我也尝试了签名@nb.njit( nb.int32(nb.float64[:]) )
但结果仍然是一样的。
我在某处犯了错误吗?
谢谢。
不是len((部分增加了时间。 使用输入参数调用 jit 函数会增加开销,这就是您看到的时差。
import numba as nb
def py_pass(i):
return i
@nb.njit()
def nb_pass(i):
return i
%timeit py_pass(1)
%timeit nb_pass(1)
包含输入参数的结果
102 ns ± 0.371 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
165 ns ± 0.783 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
有趣的是,如果你不需要向 jit 函数传递任何东西,它会更快:
def py_pass():
return 1
@nb.njit()
def nb_pass():
return 1
%timeit py_pass()
%timeit nb_pass()
不带输入参数的结果
96.6 ns ± 0.278 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
75.8 ns ± 0.221 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
正如另一个答案所说,这不是因为在这种情况下len
函数,而是因为对 numba 函数的调用实际上比对普通 Python 函数的调用慢。
是什么让jit
-ted 函数与众不同?
要理解为什么调用 numba jitted 函数的速度较慢,必须了解 numba jitted 函数不再是函数。它是一个调度程序对象:
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
print(nb_len) # CPUDispatcher(<function nb_len at 0x0000027EB1B4E798>)
此CPUDispatcher
实例表示(可能(基于修饰函数生成的多个已编译函数。
这意味着当您调用CPUDispatcher
实例时,有多个步骤:
- 获取参数的类型。
- 如果没有适合这些类型的参数的编译函数,请使用参数类型编译修饰函数。
- 有时:将参数转换为相应的 numba 类型。
- 调用已编译的函数。
与非修饰函数相比,所有这些步骤都会增加开销。特别是如果没有合适的编译函数并且调度程序需要编译函数 - 或者 - 输入类型需要转换(仅适用于 Python 类型,例如:列表、集合、字典(调用CPUDispatcher
会慢得多 - 这些类型在撰写 numba 0.46 时已被弃用,部分原因是, 请参阅"2.11.2.弃用列表和集类型的反射"。
在您的情况下
在您的情况下,由于编译,对抖动函数的首次调用将明显变慢。
任何后续调用只会稍微慢一点,因为 numba 必须获取参数类型,检查是否已经有一个编译的函数,然后调用该编译的函数。有趣的是,额外的时间取决于参数的数量和该函数已经编译的"重载"的数量。通常,这个额外的时间是微不足道的,因为该函数的作用远不止调用len
。
编译时间
尽管该函数非常简单,但第一次调用时的编译需要花费大量时间:
import numpy as np
import numba as nb
def first_call(seq):
@nb.njit
def nb_len(seq):
return len(seq)
return nb_len(seq)
@nb.njit
def _nb_len(seq):
return len(seq)
def subsequent_calls(seq):
return _nb_len(seq)
t = np.random.rand(1000)
_nb_len(np.ones(1, dtype=np.float64))
%timeit first_call(t)
# 29.8 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit subsequent_calls(t)
# 384 ns ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
转换时间
此外,如果 numba 需要转换参数,它会慢得多。这只发生在 numba 无法直接处理的 Python 类型上,例如列表:
import numpy as np
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
arr = np.random.rand(10_000)
lst = arr.tolist()
nb_len(arr)
nb_len(lst)
%timeit nb_len(arr)
# 354 ns ± 24 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit nb_len(lst)
# 14.1 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
总结
- 与普通的 Python 函数相比,Numba 函数有一些额外的开销。因此,请确保你做了numba擅长优化的"足够"的事情,否则一个简单的Python函数将更快,更灵活,更容易调试。
- numba 函数中的函数调用实际上可以不同于 numba 函数之外的函数调用。因此,
nb_len
中的len()
和py_len
中的len()
可以具有完全不同的运行时间。但是,在这种情况下,运行时几乎相同。但通常最好意识到这一点。 - 根据参数类型,numba 函数可能(在幕后(非常慢,特别是如果将 Python 类型作为参数或返回类型处理!