a和b的导数,使用算法微分

  • 本文关键字:算法 python numpy jax
  • 更新时间 :
  • 英文 :


我的任务是使用jax查找a和b的导数功能

现在,我来这里的原因是因为我对Python了解不够,而对于有问题的课程,我们也没有被认为是Python。

任务是:

return a tuple (dfa, dfb) such that dfa is the partial derivatives of f by a,
and dfb is the partial derivative of f by b

现在,我可以用正常的方式:

def function(a, b):
dfa = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
dfb = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
return (dfa, dfb)

但我不熟悉算法微分,使用我们给出的例子,我尝试过:

def foo():
x = (2/b)*sym.cos(a)
y = sym.exp(-sym.Pow(a/b,2))
return (x*y)
def f_partial_derviatives_algo():
return jax.grad(foo)

但我得到了这个错误:

无法解压缩不可迭代的函数对象

如果有人能帮助我理解如何做这样的事情,将不胜感激

JAX和sympy不兼容。你应该使用一个或另一个,不要试图将两者结合起来。

如果您想使用JAX计算这个函数在某个值上的偏导数,您可以写这样的东西:

import jax.numpy as jnp
from jax import grad
def f(a, b):
return (2 / b) * jnp.cos(a) * jnp.exp(- a ** 2 / b ** 2)
df_da = grad(f, argnums=0)
df_db = grad(f, argnums=1)
print(df_da(1.0, 1.0), df_db(1.0, 1.0))
# -1.4141841 0.3975322

最新更新