JAX和Sympy中的导数不一致



对于这个向量函数,我想求雅可比矩阵:

import jax
import jax.numpy as jnp
def myf(arr, phi_0, phi_1, phi_2, lambda_0, R):
arr = jnp.deg2rad(arr)
phi_0 = jnp.deg2rad(phi_0)
phi_1 = jnp.deg2rad(phi_1)
phi_2 = jnp.deg2rad(phi_2)
lambda_0 = jnp.deg2rad(lambda_0)

n = jnp.sin(phi_1)

F = 2.0
rho_0 = 1.0
rho = R*F*(1/jnp.tan(jnp.pi/4 + arr[1]/2))**n
x_L = rho*jnp.sin(n*(arr[0] - lambda_0))
y_L = rho_0 - rho*jnp.cos(n*(arr[0] - lambda_0))

return jnp.array([x_L,y_L])

arr = jnp.array([-18.1, 29.9])
jax.jacobian(myf)(arr, 29.5, 29.5, 29.5, -17.0, R=1)

我获得

[[ 0.01312758  0.00014317]
[-0.00012411  0.01514319]]

我对这些值感到震惊。以元素[0][0], 0.01312758为例。我们知道它是x_L关于变量arr[0]的偏导。无论是手工还是使用sympy,导数都是~0.75。

from sympy import *
x, y = symbols('x y')
x_L = (2.0*(1/tan(3.141592/4 + y/2))**0.492)*sin(0.492*(x + 0.2967))
deriv = Derivative(x_L, x)
deriv.doit()
deriv.doit().evalf(subs={x: -0.3159, y: 0.52})
0.752473089673695

(插入x, y,即已经用弧度表示的arr[0]arr[1])。这也是我手工得到的结果。Jax结果发生了什么?我看不出我做错了什么。

JAX代码段输入角度,因此其梯度的单位为1/度,而sympy代码段输入弧度,因此其梯度的单位为1/弧度。如果您将jax输出转换为1/弧度(即将jax输出乘以180/pi),您将得到您正在寻找的结果:

result = jax.jacobian(myf)(arr, 29.5, 29.5, 29.5, -17.0, R=1)
print(result * 180 / jnp.pi)
[[ 0.7521549   0.00820279]
[-0.00711098  0.8676407 ]]

或者,您可以重写myf以接受以弧度为单位的输入,并通过直接取其梯度来获得预期的结果。

好了,我想我知道是怎么回事了…这是微妙的。

问题是在函数内部完成的从度到度的转换是而不是对jax有效。我认为(但肯定有比我知道的人),jax做衍生一旦jax.jacobian(myf)被调用,它只在最后评估,当值传递(懒惰的评估,我认为),所以函数内值的转换不做任何事情。正确的代码是
def myf(arr, phi_0, phi_1, phi_2, lambda_0, R):

n = jnp.sin(phi_1)

F = 2.0
rho_0 = 1.0
rho = R*F*(1/jnp.tan(jnp.pi/4 + arr[1]/2))**n
x_L = (R*F*(1/jnp.tan(jnp.pi/4 + arr[1]/2))**n) *jnp.sin(n*(arr[0] - lambda_0))
y_L = rho_0 - (R*F*(1/jnp.tan(jnp.pi/4 + arr[1]/2))**n) *jnp.cos(n*(arr[0] - lambda_0))

return jnp.array([x_L,y_L])

arr = jnp.array([-18.1, 29.9])
jax.jacobian(myf)(jnp.deg2rad(arr), jnp.deg2rad(29.5),
jnp.deg2rad(29.5), jnp.deg2rad(29.5), jnp.deg2rad(-17.0),
R=1)
# [[ 0.7521549   0.00820279]
#  [-0.00711098  0.8676407 ]]

相关内容

  • 没有找到相关文章

最新更新