如何在使用for循环时减少JAX编译时间?



这是一个基本的例子。

@jax.jit
def block(arg1, arg2):
for x1 in range(cons1):
for x2 in range(cons2):
for x3 in range(cons3):
--do something--
return result

当cons很小时,编译时间大约为一分钟。对于更大的缺点,编译时间要高得多——几十分钟。我需要更大的筹码,有什么办法?据我所知,循环是原因。它们在编译时展开。有什么变通办法吗?还有jax.fori_loop。但我不知道怎么用。有jax.experimental.loops模块,但是我还是不能理解它。

我对这一切都很陌生。因此,感谢所有的帮助。如果你能提供一些如何使用jax循环的例子,我将不胜感激。

还有,什么是合适的编译时间?几分钟后可以吗?在其中一个示例中,编译时间为262秒,其余运行时间约为0.1-0.2秒。

运行时的任何增益都会被编译时间所掩盖。

JAX的JIT编译器使所有Python循环扁平化。要理解我的意思,可以看一下在jax.make_jaxpr中运行的这个简单函数,这是一种检查jaxs跟踪程序如何解释python代码的方法(有关更多信息,请参阅理解Jaxprs):

import jax
def f(x):
for i in range(5):
x += i
return x
print(jax.make_jaxpr(f)(0))
# { lambda  ; a.
#   let b = add a 0
#       c = add b 1
#       d = add c 2
#       e = add d 3
#       f = add e 4
#   in (f,) }

注意,循环是扁平的:每一步都成为发送给XLA编译器的显式操作。XLA编译时间会随着函数中操作次数的增加而增加,因此三层嵌套的for循环会导致较长的编译时间是有道理的。

那么,如何解决这个问题?嗯,不幸的是,答案取决于你的--do something--在做什么,所以我猜不出来。

一般来说,最好的选择是使用向量化数组操作,而不是循环遍历这些向量中的值;例如,下面是添加两个向量的非常慢的方法:
import jax.numpy as jnp
def f_slow(x, y):
z = []
for xi, yi in zip(xi, yi):
z.append(xi + yi)
return jnp.array(z)

这里有一个更快的方法来做同样的事情:

def f_fast(x, y):
return x + y

如果你的操作不适合向量化,另一个选择是使用宽松的控制流操作符来代替for循环:这将把循环推入XLA。这在CPU上可以有相当好的性能,但是与等效的矢量数组操作相比,在加速器上速度较慢。

有关JAX和Python控制流语句(如for,if,while等)的更多讨论,请参阅🔪JAX - The Sharp Bits🔪:控制流。

我不确定这是否会与numba相同,但这可能是类似的情况。

当我使用numba.jit编译器并且有大数据输入时,我先在一些小的示例数据上编译函数,然后再使用它。

伪代码:

func_being_compiled(small_amount_of_data)  # compile-only purpose
func_being_compiled(large_amount_of_data)

最新更新