这个问题与这里的问题类似,但我无法与我应该改变的内容联系起来。
我有一个功能
def elbo(variational_parameters, eps, a, b):
...
return theta, _
elbo = jit(elbo, static_argnames=["a", "b"])
其中variational_parameters
是长度为 P 的向量(一维数组),eps
是 K x N 维的二维数组,a
、b
是固定值。
elbo
已成功vmap
eps
行,并通过传递a
和b
来static_argnames
进行设置jit
,以返回theta
,这是一个维度为K乘P的二维数组。
我想通过elbo
函数获取输出theta
的雅可比函数相对于variational_parameters
。返回的第一个值
jacobian(elbo, argnums=0, has_aus=True)(variational_parameters, eps, a, b)
给了我一个三维数组,维度为 K x P x N。这就是我想要的。一旦我尝试抖动这个函数
jit(jacobian(elbo, argnums=0, has_aus=True))(variational_parameters, eps, a, b)
我收到错误
ValueError: Non-hashable static arguments are not supported, which can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function elbo is non-hashable.
任何帮助将不胜感激;谢谢!
传递给 JIT 编译函数的任何参数将不再是静态的,除非显式标记它们。所以这行:
jit(jacobian(elbo, argnums=0, has_aus=True))(variational_parameters, eps, a, b)
使variational_parameters
、eps
、a
和b
非静态。然后在转换后的函数中,这些非静态参数被传递给此函数:
elbo = jit(elbo, static_argnames=["a", "b"])
这意味着您正在尝试将非静态值作为静态参数传递,这会导致错误。
若要解决此问题,应在静态参数进入 jit 编译函数时将其标记为静态。在您的情况下,它可能看起来像这样:
jit(jacobian(elbo, argnums=0, has_aus=True),
static_argnums=(2, 3))(variational_parameters, eps, a, b)