是否有一种方法来加速索引与JAX矢量?



我正在索引向量并使用JAX,但是我注意到在简单索引数组时,与numpy相比,速度要慢得多。例如,考虑在JAX numpy和普通numpy中创建一个基本数组:

import jax.numpy as jnp
import numpy as onp 
jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)

然后简单地在两个整数之间建立索引,对于JAX(在GPU上),这给出了时间:

%timeit jax_array[435:852]

1000个循环,最佳5:1.38 ms/loop

对于numpy,它给出的时间为:

%timeit numpy_array[435:852]

1000000个循环,最佳5:271 ns/loop

所以numpy比JAX快5000倍。当JAX在CPU上时,

%timeit jax_array[435:852]

1000个循环,最佳5:577µs/loop

这么快,但仍然比numpy慢2000倍。我使用的是谷歌Colab笔记本电脑,所以安装/CUDA应该不会有问题。

我错过了什么吗?我意识到索引对于JAX和numpy是不同的,正如JAX"锐边"文档所给出的,但我找不到任何方法来执行赋值,如

new_array = jax_array[435:852]

没有明显的减速。我无法避免对数组进行索引,因为这在我的程序中是必要的。

简短的回答是:为了在JAX中加快速度,使用jit

长答案:

您通常应该期望在op-by-op模式下使用JAX的单个操作比在numpy中进行类似操作要慢。这是因为JAX执行有一些固定的每个python函数调用开销,涉及到将编译推到XLA。

甚至像索引这样看似简单的操作都是根据多个XLA操作实现的,这些操作(在JIT之外)每个操作都会增加它们自己的调用开销。您可以使用make_jaxpr变换来查看该函数是如何用基本操作来表示的:

from jax import make_jaxpr
f = lambda x: x[435:852]
make_jaxpr(f)(jax_array)
# { lambda  ; a.
#   let b = broadcast_in_dim[ broadcast_dimensions=(  )
#                             shape=(1,) ] 435
#       c = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))
#                   indices_are_sorted=True
#                   slice_sizes=(417,)
#                   unique_indices=True ] a b
#       d = broadcast_in_dim[ broadcast_dimensions=(0,)
#                             shape=(417,) ] c
#   in (d,) }

(有关如何阅读本文的信息,请参阅理解Jaxprs)。

JAX优于numpy的地方不是单个小操作(JAX调度开销占主导地位),而是通过jit转换编译的一系列操作。因此,例如,比较jit编译与非jit编译版本的索引:

%timeit f(jax_array).block_until_ready()
# 1000 loops, best of 5: 612 µs per loop
f_jit = jit(f)
f_jit(jax_array)  # trigger compilation
%timeit f_jit(jax_array).block_until_ready()
# 100000 loops, best of 5: 4.34 µs per loop

(注意,由于JAX的异步调度,精确的微基准测试需要block_until_ready())

用jit编译这段代码可以获得150倍的加速。它仍然不如numpy快,因为JAX的调度开销只有几毫秒,但是使用JIT时,这种开销只会产生一次。当您从微基准测试转向更复杂的实际计算序列时,那几毫秒将不再占主导地位,XLA编译器提供的优化可以使JAX比等效的numpy计算快得多。

最新更新