Flux的自定义渐变,而不是使用Zygote A.D



我有一个机器学习模型,其中模型参数的梯度是解析的,不需要自动微分。然而,我仍然希望能够利用Flux中的不同优化器,而不必依赖Zygote进行差异化。以下是我的一些代码片段。

W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)
θ = Flux.Params([b, c, U, W])
opt = ADAM(0.01)

然后我有一个函数来计算我的模型参数θ的分析梯度。

function gradients(x) # x = one input data point or a batch of input data points
# stuff to calculate gradients of each parameter
# returns gradients of each parameter

然后,我希望能够做以下事情。

grads = gradients(x)
update!(opt, θ, grads)

我的问题是:我的gradient(x)函数需要返回什么形式/类型才能执行update!(opt, θ, grads),我该如何做到这一点?

如果不使用Params,则grads只需要是渐变即可。唯一的要求是θgrads的大小相同。

例如,map((x, g) -> update!(opt, x, g), θ, grads),其中θ == [b, c, U, W]grads = [gradients(b), gradients(c), gradients(U), gradients(W)](不确定gradients期望什么作为输入(。

更新:但要回答您最初的问题,gradients需要返回此处找到的Grads对象:https://github.com/FluxML/Zygote.jl/blob/359e586766129878ca0e56121037ed80afda6289/src/compiler/interface.jl#L88

所以类似的东西

# within gradient function body assuming gb is the gradient w.r.t b
g = Zygote.Grads(IdDict())
g.grads[θ[1]] = gb # assuming θ[1] == b

但不使用Params可能更容易调试。唯一的问题是,没有一个update!可以处理一系列参数,但你可以很容易地定义自己的:

function Flux.Optimise.update!(opt, xs::Tuple, gs)
for (x, g) in zip(xs, gs)
update!(opt, x, g)
end
end
# use it like this
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)
θ = (b, c, U, W)
opt = ADAM(0.01)
x = # generate input to gradients
grads = gradients(x) # return tuple (gb, gc, gU, gW)
update!(opt, θ, grads)

更新2:

另一种选择是仍然使用Zygote来获取梯度,以便它自动为您设置Grads对象,但使用自定义伴随,以便它使用您的分析函数来计算伴随。假设您的ML模型被定义为名为f的函数,那么f(x)将为输入x返回模型的输出。我们还假设gradients(x)返回分析梯度w.r.t.x,就像您在问题中提到的那样。然后,以下代码仍将使用Zygote的AD,它将正确填充Grads对象,但它将使用您为函数f:计算梯度的定义

W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)
θ = Flux.Params([b, c, U, W])
f(x) = # define your model
gradients(x) = # define your analytical gradient
# set up the custom adjoint
Zygote.@adjoint f(x) = f(x), Δ -> (gradients(x),)
opt = ADAM(0.01)
x = # generate input to model
y = # output of model
grads = Zygote.gradient(() -> Flux.mse(f(x), y), θ)
update!(opt, θ, grads)

请注意,我在上面使用了Flux.mse作为损失示例。这种方法的一个缺点是Zygote的gradient函数需要标量输出。如果您的模型被传递到某个将输出标量误差值的损失中,那么@adjoint是最好的方法。这将适用于您正在进行标准ML的情况,唯一的变化是您希望Zygote使用您的函数解析计算f的梯度。

如果你正在做一些更复杂的事情,并且不能使用Zygote.gradient,那么第一种方法(不使用Params(是最合适的。Params的存在实际上只是为了与Flux的旧AD向后兼容,所以如果可能的话,最好避免它。

最新更新