JAX: passing a dictionary rather than arg numbers to identify variables for autodiff



I want to use JAX as a vehicle for gradient descent; however, I have quite a few parameters and would prefer to pass them as a dictionary, f(func, dict), rather than as positional arguments, f(func, x1, ... xn).

So instead of

# https://www.kaggle.com/code/grez911/tutorial-efficient-gradient-descent-with-jax/notebook
def J(X, w, b, y):
    """Cost function for a linear regression. A forward pass of our model.

    Args:
        X: a features matrix.
        w: weights (a column vector).
        b: a bias.
        y: a target vector.
    Returns:
        scalar: a cost of this solution.
    """
    y_hat = X.dot(w) + b            # Predict values.
    return ((y_hat - y)**2).mean()  # Return cost.

for i in range(100):
    w -= learning_rate * grad(J, argnums=1)(X, w, b, y)
    b -= learning_rate * grad(J, argnums=2)(X, w, b, y)

I would like something more like

for i in range(100):
    w -= learning_rate * grad(J, arg_key='w')(arg_dict)
    b -= learning_rate * grad(J, arg_key='b')(arg_dict)

Is this possible?

EDIT:

Here is my current working solution:

import jax.numpy as np
from jax import grad

# A features matrix.
X = np.array([
    [4., 7.],
    [1., 8.],
    [-5., -6.],
    [3., -1.],
    [0., 9.]
])
# A target column vector.
y = np.array([
    [37.],
    [24.],
    [-34.],
    [16.],
    [21.]
])
learning_rate = 0.01
w = np.zeros((2, 1))
b = 0.

def J(X, w, b, y):
    """Cost function for a linear regression. A forward pass of our model.

    Args:
        X: a features matrix.
        w: weights (a column vector).
        b: a bias.
        y: a target vector.
    Returns:
        scalar: a cost of this solution.
    """
    y_hat = X.dot(w) + b            # Predict values.
    return ((y_hat - y)**2).mean()  # Return cost.

# Define your function arguments as a dictionary.
arg_dict = {'X': X, 'w': w, 'b': b, 'y': y}
# Only these entries are trainable; X and y are data and must not be updated.
trainable = {'w', 'b'}
idx_dict = {idx: name for idx, name in enumerate(arg_dict.keys())}

for i in range(100):
    arg_arr = [arg_dict[idx_dict[idx]] for idx in range(len(arg_dict))]
    for idx, name in idx_dict.items():
        if name in trainable:
            # Write the update back into the dict; a bare `var -= ...`
            # would only rebind a local variable and leave arg_dict unchanged.
            arg_dict[name] = arg_dict[name] - learning_rate * grad(J, argnums=idx)(*arg_arr)

The point is that I no longer need to write a separate grad(…) call for every variable that needs autodiff.
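As an aside, JAX's grad can also differentiate with respect to a pytree argument such as a dict, so another option is to bundle the trainable parameters into a single dict argument. Here is a minimal sketch under that assumption (note the rearranged signature of J compared with the version above):

import jax.numpy as np
from jax import grad

def J(params, X, y):
    """Same cost as above, but the trainable parameters live in one dict."""
    y_hat = X.dot(params['w']) + params['b']
    return ((y_hat - y)**2).mean()

params = {'w': np.zeros((2, 1)), 'b': 0.}
for i in range(100):
    grads = grad(J)(params, X, y)   # grads is a dict with the same keys as params
    params = {k: params[k] - learning_rate * grads[k] for k in params}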

Specifying autodiff argnums by name is not currently supported in JAX, though the idea is under discussion: https://github.com/google/jax/issues/10614

Until that is implemented, there are ways to convert argnames to argnums automatically using inspect.signature on your function (some examples are in the linked issue), but in general it may be simpler to do the mapping by hand for your particular function.
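For example, here is a minimal sketch of such a wrapper (grad_by_name is a hypothetical helper name, not part of JAX):

import inspect
from jax import grad

def grad_by_name(f, *argnames):
    """Map argument names to positional indices, then defer to jax.grad."""
    names = list(inspect.signature(f).parameters)
    argnums = tuple(names.index(name) for name in argnames)
    return grad(f, argnums=argnums)

# Usage with the J(X, w, b, y) defined above:
dw, db = grad_by_name(J, 'w', 'b')(X, w, b, y)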
