Jax - Debugging NaN-values



大家晚上好,

在过去的6个小时里,我试图调试Jax中看似随机出现的NaN值。我已经缩小了NaN最初来源于损失函数或其梯度的范围。

此处提供可复制错误的最小笔记本https://colab.research.google.com/drive/1uXa-igMm9QBOOl8ZNdK1OkwxRFlLqvZD?usp=sharing

作为Jax的一个用例,这可能也很有趣。当只有有限数量的陀螺仪/加速度计测量可用时,我使用Jax来解决方位估计任务。在这里,四元数运算的有效实现是很好的。

训练循环一开始很好,但最终偏离

Step 0| Loss: 4.550444602966309 | Time: 13.910547971725464s
Step 1| Loss: 4.110116481781006 | Time: 5.478027105331421s
Step 2| Loss: 3.7159230709075928 | Time: 5.476970911026001s
Step 3| Loss: 3.491917371749878 | Time: 5.474078416824341s
Step 4| Loss: 3.232130765914917 | Time: 5.433410406112671s
Step 5| Loss: 3.095140218734741 | Time: 5.433837413787842s
Step 6| Loss: 2.9580295085906982 | Time: 5.429029941558838s
Step 7| Loss: nan | Time: 5.427825689315796s
Step 8| Loss: nan | Time: 5.463077545166016s
Step 9| Loss: nan | Time: 5.479652643203735s

这可以通过发散梯度来追溯,从下面的片段中可以看出

(loss, _), grads = loss_fn(params, X[0], y[0], rnn.reset_carry(bs=2))
grads["params"]["Dense_0"]["bias"] # shape=(bs, out_features)
DeviceArray([[-0.38666773,         nan, -1.0433975 ,         nan],
[ 0.623061  , -0.20950513,  0.8459796 , -0.42356613]],            dtype=float32)

我的问题是:如何调试这个

启用NaN调试

启用nan调试并没有真正的帮助,因为它只会导致带有许多隐藏痕迹的巨大堆栈。。

from jax.config import config
config.update("jax_debug_nans", True)

任何帮助都将不胜感激!谢谢:(

一些方法(在主文档中有适当的记录(可能会起作用:

  1. 作为一个修补程序,切换到float64就可以了。更多信息请点击:jax.config.update("jax_enable_x64", True)
  2. 渐变剪裁是你所需要的一切(文档(
  3. 有时你可以实现自己的反向运算,例如,当你将两个饱和的函数组合成一个不饱和的函数时,或者在奇点处强制执行值时,这会有所帮助
  4. 通过检查计算图来诊断你的反向探测。通常查找用div标记表示的分区:
from jax import make_jaxpr
# If grad_fn(x) gives you trouble, you can inspect the computation as follows:
grad_fn = jit(value_and_grad(my_forward_prop, argnums=0))
make_jaxpr(grad_fn)(x)

请注意,社区非常活跃,已经并正在添加一些支持来诊断NaNs:

  • ;jax_debug_nans";config标志
  • 更多线程

希望这能有所帮助
Andres

最新更新