在Jax中该导数为零



实现极坐标到笛卡尔坐标的雅可比矩阵,我在Jax中获得一个零数组,它不能是

theta = np.pi/4
r = 4.0

var = np.array([r, theta])
x = var[0]*jnp.cos(var[1])
y = var[0]*jnp.sin(var[1])
def f(var):
return np.array([x, y])

jac = jax.jacobian(f)(var)
jac
#DeviceArray([[0., 0.],
#             [0., 0.]], dtype=float32)

我错过了什么?

你的函数不依赖于var,因为x, y是在函数外定义的。

这将给出期望的输出:

theta = np.pi/4
r = 4.0

var = np.array([r, theta])
def f(var):
x = var[0]*jnp.cos(var[1])
y = var[0]*jnp.sin(var[1])
return jnp.array([x, y])

jac = jax.jacobian(f)(var)
jac

请注意,您需要返回一个jax numpy数组,而不是numpy数组。

相关内容

  • 没有找到相关文章

最新更新