如何使用jax创建物理信息神经网络(PINN)



我正在尝试在JAX中创建一个物理知情的神经网络(PINN)。我想通过输入(x)对定义的模型(神经网络)进行微分,如果我将model设置为jax.grad(params),我会得到一个错误。
如果我将model设置为jax.grad(model),我没有得到错误,但我不知道我是否能够通过x区分神经网络的模型。

class MLP(fnn.Module):
@fnn.compact
def __call__(self, x):
x = fnn.Dense(128)(x)
x = fnn.relu(x)
x = fnn.Dense(256)(x)
x = fnn.relu(x)
x = fnn.Dense(10)(x)
return x
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1]))['params']
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

您可以通过以下方式在JAX中区分模型:(1)定义您想要区分的函数,(2)根据您的应用程序使用jax.grad,jax.jacrev,jax.jacfwd等对其进行转换,以及(3)将数据传递给转换后的函数。

从你的问题中你希望区分什么操作并不完全清楚,但这里有一个计算关于参数的训练状态创建的前向雅可比矩阵的例子:

def f(params):
return TrainState.create(apply_fn=model.apply, params=params, tx=tx)
result = jax.jacfwd(f)(params)

如果没有帮助,我建议修改你的问题,以明确你感兴趣的是什么操作。

最新更新