在 JAX 的快速入门教程中,我发现可以使用以下代码行有效地计算可微函数fun
的 Hessian 矩阵:
from jax import jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
但是,也可以通过计算以下内容来计算黑森州:
def hessian(fun):
return jit(jacrev(jacfwd(fun)))
def hessian(fun):
return jit(jacfwd(jacfwd(fun)))
def hessian(fun):
return jit(jacrev(jacrev(fun)))
这是一个最小的工作示例:
import jax.numpy as jnp
from jax import jit
from jax import jacfwd, jacrev
def comp_hessian():
x = jnp.arange(1.0, 4.0)
def sum_logistics(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
def hessian_1(fun):
return jit(jacfwd(jacrev(fun)))
def hessian_2(fun):
return jit(jacrev(jacfwd(fun)))
def hessian_3(fun):
return jit(jacrev(jacrev(fun)))
def hessian_4(fun):
return jit(jacfwd(jacfwd(fun)))
hessian_fn = hessian_1(sum_logistics)
print(hessian_fn(x))
hessian_fn = hessian_2(sum_logistics)
print(hessian_fn(x))
hessian_fn = hessian_3(sum_logistics)
print(hessian_fn(x))
hessian_fn = hessian_4(sum_logistics)
print(hessian_fn(x))
def main():
comp_hessian()
if __name__ == "__main__":
main()
我想知道哪种方法最好以及何时使用?我还想知道是否可以使用grad()
来计算黑森州?grad()
与jacfwd
和jacrev
有何不同?
您的问题的答案在 JAX 文档中;例如,请参阅本节:https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev
引用其对jacrev
和jacfwd
的讨论:
这两个函数计算相同的值(最多为机器数字),但实现方式不同:
jacfwd
使用正向模式自动微分,这对于"高"雅可比矩阵更有效,而jacrev
使用反向模式,这对于"宽"雅可比矩阵更有效。对于近平方矩阵,jacfwd
可能比jacrev
具有优势。
再往下,
为了实现 hessian,我们可以使用
jacfwd(jacrev(f))
或jacrev(jacfwd(f))
或两者的任何其他组合。但正反转通常是最有效的。这是因为在内部雅可比计算中,我们经常区分一个函数宽雅可比(可能像损失函数:Rⁿ→R),而在外部雅可比计算中,我们用平方雅可比计算来区分一个函数(因为∇ :Rⁿ→Rⁿ),这是前向模式胜出的地方。
由于您的函数看起来像:Rⁿ→R,因此jit(jacfwd(jacrev(fun)))
可能是最有效的方法。
至于为什么你不能用grad
实现一个 hessian,这是因为grad
只为具有标量输出的函数的导数而设计。根据定义,黑森是向量值雅可比的组合,而不是标量梯度的组合。