两种矩阵元素指数化方法的相容性



我有两种在jnp = jax.numpy中为矩阵求幂的方法。A.简单的一个:

jnp.exp(-X/reg)

还有一些额外的行动:

def exp_reg(X, reg):
K = jnp.empty_like(X)
K = jnp.divide(X, -reg)
return jnp.exp(K)

然而,当我测试它们时:

%timeit jnp.exp(-X/reg).block_until_ready()
%timeit exp_reg(X, reg).block_until_ready()

尽管表面上有一些额外的开销,但第二种方法的表现要好。我运行了一个%timeit,矩阵大小为2000 x 2000:

7.85 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

为什么会出现这种情况?

这里的区别在于操作顺序。

jnp.exp(-X/reg)中,您将否定X的每个条目,然后将结果的每个条目除以reg。这是对数组X的两次传递。

exp_reg中,您正在否定reg(可能是标量值?(,然后将X除以结果。这是X的一次传球。

如果X很大,我预计第一种方法会比第二种稍慢,因为X上有多次通过。

幸运的是,由于您使用的是JAX,所以可以jit编译您的代码,在这种情况下,XLA通常可以在类似的操作顺序上进行优化。事实上,对于您的两个函数,编译消除了差异:

from jax import jit
import jax.numpy as jnp
import numpy as np
def exp_reg1(X, reg):
return jnp.exp(-X/reg)
def exp_reg2(X, reg):
K = jnp.divide(X, -reg)
return jnp.exp(K)
X = jnp.array(np.random.rand(1000, 1000))
reg = 2.0
%timeit exp_reg1(X, reg)
# 100 loops, best of 3: 3.17 ms per loop
%timeit exp_reg2(X, reg)
# 100 loops, best of 3: 2.2 ms per loop
# Trigger compilation
jit(exp_reg1)(X, reg)
jit(exp_reg2)(X, reg)
%timeit jit(exp_reg1)(X, reg)
# 1000 loops, best of 3: 1.92 ms per loop
%timeit jit(exp_reg2)(X, reg)
# 100 loops, best of 3: 1.84 ms per loop

(附带说明:在将运算结果分配给同名变量之前,没有理由预先分配一个空数组K(。

相关内容

最新更新