为什么jax.Grad (lambda v: jnp. linalgn .norm(v-v))(jnp.ones(2))



有人能解释一下下面的行为吗?是bug吗?

from jax import grad
import jax.numpy as jnp
x = jnp.ones(2)
grad(lambda v: jnp.linalg.norm(v-v))(x) # returns DeviceArray([nan, nan], dtype=float32)
grad(lambda v: jnp.linalg.norm(0))(x) # returns DeviceArray([0., 0.], dtype=float32)

我试着在网上查找错误,但没有找到任何相关的。

我还浏览了https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

计算grad(lambda v: jnp.linalg.norm(v-v))(x)时,函数大致如下:

f(x) = sqrt[(x - x)^2]

因此,用链式法则求值,导数是

df/dx = (x - x) / sqrt[(x - x)^2]

,当你插入任何有限的x时,它的值为

0 / sqrt(0)

没有定义,在浮点运算中用NaN表示。

计算grad(lambda v: jnp.linalg.norm(0))(x)时,函数大致如下:

g(x) = sqrt[0.0^2]

,因为它不依赖于x,导数就是

dg/dx = 0.0

这样回答你的问题了吗?

最新更新