尝试在 JAX 中对雅可比计算进行 JJ 时出错:"ValueError: Non-hashable static arguments are not supported"



这个问题与这里的问题类似,但我无法与我应该改变的内容联系起来。

我有一个功能

def elbo(variational_parameters, eps, a, b):
...
return theta, _
elbo = jit(elbo, static_argnames=["a", "b"])

其中variational_parameters是长度为 P 的向量(一维数组),eps是 K x N 维的二维数组,ab是固定值。

elbo已成功vmapeps,并通过传递abstatic_argnames进行设置jit,以返回theta,这是一个维度为K乘P的二维数组。

我想通过elbo函数获取输出theta的雅可比函数相对于variational_parameters。返回的第一个值

jacobian(elbo, argnums=0, has_aus=True)(variational_parameters, eps, a, b)

给了我一个三维数组,维度为 K x P x N。这就是我想要的。一旦我尝试抖动这个函数

jit(jacobian(elbo, argnums=0, has_aus=True))(variational_parameters, eps, a, b)

我收到错误

ValueError: Non-hashable static arguments are not supported, which can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function elbo is non-hashable.

任何帮助将不胜感激;谢谢!

传递给 JIT 编译函数的任何参数将不再是静态的,除非显式标记它们。所以这行:

jit(jacobian(elbo, argnums=0, has_aus=True))(variational_parameters, eps, a, b)

使variational_parametersepsab非静态。然后在转换后的函数中,这些非静态参数被传递给此函数:

elbo = jit(elbo, static_argnames=["a", "b"])

这意味着您正在尝试将非静态值作为静态参数传递,这会导致错误。

若要解决此问题,应在静态参数进入 jit 编译函数时将其标记为静态。在您的情况下,它可能看起来像这样:

jit(jacobian(elbo, argnums=0, has_aus=True),
static_argnums=(2, 3))(variational_parameters, eps, a, b)

相关内容

最新更新