计算亚麻神经网络输出与输入的Hessian向量积



我试图得到输出的二阶导数w.r.t使用亚麻构建的神经网络的输入。网络结构如下:

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax import optim
class MLP(nn.Module):
features: Sequence[int]
@nn.compact
def __call__(self, x):
for feat in self.features[:-1]:
x = nn.tanh(nn.Dense(feat)(x))
x = nn.Dense(self.features[-1])(x)
return x
model = MLP([20, 20, 20, 20, 20, 1])
batch = jnp.ones((32, 3)) #Dummy input to Initialize the NN
params = model.init(jax.random.PRNGKey(0), batch)
X =  jnp.ones((32, 3))
output = model.apply(params, X)

我可以通过vmap over grad得到单阶导数:

@jit
def u_function(params, X):
u = model.apply(params, X)
return jnp.squeeze(u)
grad_fn = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))
u_X = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)

然而,当我再次尝试这样做以获得二阶导数时:

u_X_func = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))
u_XX_func = vmap(grad(u_X_func, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)

我得到以下错误:

[/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py](https://localhost:8080/#) in __call__(self, inputs)
186     kernel = self.param('kernel',
187                         self.kernel_init,
--> 188                         (jnp.shape(inputs)[-1], self.features),
189                         self.param_dtype)
190     if self.use_bias:
IndexError: tuple index out of range

我尝试使用autodiff烹饪书中的hvp定义,但是参数是函数的输入,只是不确定如何进行。

如果有任何帮助,我将非常感激。

问题是您的u_function将长度为3的向量映射到标量。它的一阶导数是一个长度为3的向量,但它的二阶导数是一个3x3的哈希矩阵,你不能通过jax.grad来计算,它只适用于标量输出函数。幸运的是JAX提供了jax.hessian变换来计算这些一般的二阶导数:

u_XX = vmap(hessian(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
print(u_XX.shape)
# (32, 3, 3)

最新更新