我试图得到输出的二阶导数w.r.t使用亚麻构建的神经网络的输入。网络结构如下:
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax import optim
class MLP(nn.Module):
features: Sequence[int]
@nn.compact
def __call__(self, x):
for feat in self.features[:-1]:
x = nn.tanh(nn.Dense(feat)(x))
x = nn.Dense(self.features[-1])(x)
return x
model = MLP([20, 20, 20, 20, 20, 1])
batch = jnp.ones((32, 3)) #Dummy input to Initialize the NN
params = model.init(jax.random.PRNGKey(0), batch)
X = jnp.ones((32, 3))
output = model.apply(params, X)
我可以通过vmap over grad得到单阶导数:
@jit
def u_function(params, X):
u = model.apply(params, X)
return jnp.squeeze(u)
grad_fn = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))
u_X = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
然而,当我再次尝试这样做以获得二阶导数时:
u_X_func = vmap(grad(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))
u_XX_func = vmap(grad(u_X_func, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
我得到以下错误:
[/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py](https://localhost:8080/#) in __call__(self, inputs)
186 kernel = self.param('kernel',
187 self.kernel_init,
--> 188 (jnp.shape(inputs)[-1], self.features),
189 self.param_dtype)
190 if self.use_bias:
IndexError: tuple index out of range
我尝试使用autodiff烹饪书中的hvp定义,但是参数是函数的输入,只是不确定如何进行。
如果有任何帮助,我将非常感激。
问题是您的u_function
将长度为3的向量映射到标量。它的一阶导数是一个长度为3的向量,但它的二阶导数是一个3x3的哈希矩阵,你不能通过jax.grad
来计算,它只适用于标量输出函数。幸运的是JAX提供了jax.hessian
变换来计算这些一般的二阶导数:
u_XX = vmap(hessian(u_function, argnums=1), in_axes=(None, 0), out_axes=(0))(params, X)
print(u_XX.shape)
# (32, 3, 3)