JAX:如何在jit函数中的条件上累积jnp.array



我想在jit函数中过滤一个带有条件的jnp.array,并累加到一个全局变量(所以我们必须使用JAX控制流原语):

import jax
import jax.numpy as jnp
from jax import jit
from jax import lax
key = jax.random.PRNGKey(42)

@jit
def get_data():
data = jax.random.normal(key, (5, 3))
data = data.at[-2:].set(0.)
return data

data = get_data()
accu = data[0]

@jit
def filter(data):
def body_fun(i):
global accu
accu = jnp.vstack((accu, data[i]))
return i + 1
lax.while_loop(lambda i: jnp.all(data[i]), body_fun, 1)
filter(data)

我预计filter执行后accu.shape为(3,3)(数据中有三个非零行),但得到(2,3):

Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/1)>

我怀疑lax.while_loop迭代了第1行和第2行,但全局accu只更新了一次,但为什么?或者有没有更好的方法可以在不使用全局变量的情况下累积jnp.array(在jit函数中)?

您的body_fun使用副作用更新accu。Jax是一个函数式编程库:您需要明确所有更新。这意味着accu应该在body_fun的自变量中,并且在更新后由其返回。

jax.lax.while_loop的签名是jax.lax.while_loop(cond_fun, body_fun, init_val)。在您的情况下,init_val应该是元组(counter, accu)

最后一个问题是accu应该是固定形状的。您需要预先将其初始化为某种形状,在您的情况下,这将是最终accu形状的上限。

最后,下面的代码可以工作了。在这里,我建议不要关闭data,以表明cond_fn也依赖于数据。你可以用一个闭包来代替。

# Initialization
data = get_data()
accu = jax.zeros_like(data)
i = 0
def body_fn(carry): 
i, accu, data = carry
accu = accu.at[i].set(data[i])
return (i + 1, accu, data)
def cond_fn(carry):
i, accu, data = carry
return jnp.all(data[i])
last_i, accu, _ = lax.while_loop(body_fn, (i, accu))

最新更新