如果有两个函数,一个有jit,另一个没有,当我迭代它们100次时,unjit函数给我的时间比jit函数少


import jax
import numpy as np
import jax.numpy as jnp
a = []
a_jax = []
for i in range(10000):
a.append(np.random.randint(1, 5, (5,)))
a_jax.append(jnp.array(a[i]))
# a_jax = jnp.array(a_jax)
@jax.jit
def calc_add_with_jit(a, b):
return a + b
def calc_add_without_jit(a, b):
return a + b
def main_function_with_jit():
for i in range(99):
calc_add_with_jit(a_jax[i], a_jax[i+1]) 
def main_function_without_jit():
for i in range(99):
calc_add_without_jit(a[i], a[i+1])
%time calc_add_with_jit(a_jax[1], a_jax[2])
%time main_function_with_jit()
%time main_function_without_jit()

现在第一个CCD_ 1导致3.33ms的壁时间,第二个CCD_ 2函数导致5.58ms的时间,第三次%time导致156微秒的时间

有人能解释为什么会发生这种事吗?为什么JAX-JIT与常规代码相比速度较慢?我说的是第二次和第三次函数结果

这个问题在JAX文档中得到了很好的回答;请参阅常见问题解答:JAX比NumPy快吗?特别是,引用摘要:

如果您在CPU上对单个阵列操作进行微基准测试,您通常可以预期NumPy的性能优于JAX,因为它的每操作调度开销较低。如果你在GPU或TPU上运行代码,或者在CPU上测试更复杂的JIT编译的操作序列,你通常可以预期JAX会优于NumPy。

您正在对CPU上单独调度的单个操作序列进行基准测试,这正是NumPy设计和优化的机制,因此您可以预期NumPy会更快。

最新更新