在 JAX 中高效计算黑森矩阵

  • 本文关键字:计算 JAX 高效 jax
  • 更新时间 :
  • 英文 :


在 JAX 的快速入门教程中,我发现可以使用以下代码行有效地计算可微函数fun的 Hessian 矩阵:

from jax import jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))

但是,也可以通过计算以下内容来计算黑森州:

def hessian(fun):
return jit(jacrev(jacfwd(fun)))
def hessian(fun):
return jit(jacfwd(jacfwd(fun)))
def hessian(fun):
return jit(jacrev(jacrev(fun)))

这是一个最小的工作示例:

import jax.numpy as jnp
from jax import jit
from jax import jacfwd, jacrev
def comp_hessian():
x = jnp.arange(1.0, 4.0)
def sum_logistics(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
def hessian_1(fun):
return jit(jacfwd(jacrev(fun)))
def hessian_2(fun):
return jit(jacrev(jacfwd(fun)))
def hessian_3(fun):
return jit(jacrev(jacrev(fun)))
def hessian_4(fun):
return jit(jacfwd(jacfwd(fun)))
hessian_fn = hessian_1(sum_logistics)
print(hessian_fn(x))
hessian_fn = hessian_2(sum_logistics)
print(hessian_fn(x))
hessian_fn = hessian_3(sum_logistics)
print(hessian_fn(x))
hessian_fn = hessian_4(sum_logistics)
print(hessian_fn(x))

def main():
comp_hessian()

if __name__ == "__main__":
main()

我想知道哪种方法最好以及何时使用?我还想知道是否可以使用grad()来计算黑森州?grad()jacfwdjacrev有何不同?

您的问题的答案在 JAX 文档中;例如,请参阅本节:https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev

引用其对jacrevjacfwd的讨论:

这两个函数计算相同的值(最多为机器数字),但实现方式不同:jacfwd使用正向模式自动微分,这对于"高"雅可比矩阵更有效,而jacrev使用反向模式,这对于"宽"雅可比矩阵更有效。对于近平方矩阵,jacfwd可能比jacrev具有优势。

再往下,

为了实现 hessian,我们可以使用jacfwd(jacrev(f))jacrev(jacfwd(f))或两者的任何其他组合。但正反转通常是最有效的。这是因为在内部雅可比计算中,我们经常区分一个函数宽雅可比(可能像损失函数:Rⁿ→R),而在外部雅可比计算中,我们用平方雅可比计算来区分一个函数(因为∇ :Rⁿ→Rⁿ),这是前向模式胜出的地方。

由于您的函数看起来像:Rⁿ→R,因此jit(jacfwd(jacrev(fun)))可能是最有效的方法。

至于为什么你不能用grad实现一个 hessian,这是因为grad只为具有标量输出的函数的导数而设计。根据定义,黑森是向量值雅可比的组合,而不是标量梯度的组合。

最新更新