有人能解释一下下面的行为吗?是bug吗?
from jax import grad
import jax.numpy as jnp
x = jnp.ones(2)
grad(lambda v: jnp.linalg.norm(v-v))(x) # returns DeviceArray([nan, nan], dtype=float32)
grad(lambda v: jnp.linalg.norm(0))(x) # returns DeviceArray([0., 0.], dtype=float32)
我试着在网上查找错误,但没有找到任何相关的。
我还浏览了https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
计算grad(lambda v: jnp.linalg.norm(v-v))(x)
时,函数大致如下:
f(x) = sqrt[(x - x)^2]
因此,用链式法则求值,导数是
df/dx = (x - x) / sqrt[(x - x)^2]
,当你插入任何有限的x
时,它的值为
0 / sqrt(0)
没有定义,在浮点运算中用NaN
表示。
计算grad(lambda v: jnp.linalg.norm(0))(x)
时,函数大致如下:
g(x) = sqrt[0.0^2]
,因为它不依赖于x
,导数就是
dg/dx = 0.0
这样回答你的问题了吗?