I'm asking myself whether the following code performs just a single gradient descent step, or runs the entire gradient descent algorithm.
opt = tf.keras.optimizers.SGD(learning_rate=self.learning_rate)
train = opt.minimize(self.loss, var_list=[self.W1, self.b1, self.W2, self.b2, self.W3, self.b3])
Gradient descent needs to run for some number of steps that you decide on. But I'm not sure whether opt.minimize(self.loss, var_list=[self.W1, self.b1, self.W2, self.b2, self.W3, self.b3]) performs all of those steps rather than just one. Why do I think it performs all the steps? Because afterwards my loss is zero.
tf.keras.optimizers.Optimizer.minimize() computes the gradients and applies them once. So it performs a single step.
In the documentation of that function you can read:

This method simply computes gradient using tf.GradientTape and calls apply_gradients(). If you want to process the gradient before applying then call tf.GradientTape and apply_gradients() explicitly instead of using this function.
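The "process the gradient before applying" path the documentation mentions can be sketched as follows. This is a minimal illustrative example, not the asker's model: a single trainable scalar, with gradient clipping as the example of processing (the variable, loss, and clipping choice are all assumptions for illustration):

```python
import tensorflow as tf

# Hypothetical tiny problem: minimize (w - 3)^2 over a single scalar variable.
w = tf.Variable(0.0)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)

# Compute the gradient explicitly with a tape...
with tf.GradientTape() as tape:
    loss = (w - 3.0) ** 2
grads = tape.gradient(loss, [w])

# ...process it before applying (here: clip its norm to 1.0)...
grads = [tf.clip_by_norm(g, 1.0) for g in grads]

# ...and apply it. This is exactly one update step.
opt.apply_gradients(zip(grads, [w]))
```

Calling minimize() would do the same tape-then-apply sequence for you, just without the processing step in the middle.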
This can also be seen from the implementation of minimize():
def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
    """Minimize `loss` by updating `var_list`.

    This method simply computes gradient using `tf.GradientTape` and calls
    `apply_gradients()`. If you want to process the gradient before applying
    then call `tf.GradientTape` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: `Tensor` or callable. If a callable, `loss` should take no arguments
        and return the value to minimize. If a `Tensor`, the `tape` argument
        must be passed.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use callable when the variable list would otherwise be incomplete before
        `minimize` since the variables are created at the first time `loss` is
        called.
      grad_loss: (Optional). A `Tensor` holding the gradient computed for
        `loss`.
      name: (Optional) str. Name for the returned operation.
      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
        the tape that computed the `loss` must be provided.

    Returns:
      An `Operation` that updates the variables in `var_list`. The `iterations`
      will be automatically increased by 1.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.
    """
    grads_and_vars = self._compute_gradients(
        loss, var_list=var_list, grad_loss=grad_loss, tape=tape)
    return self.apply_gradients(grads_and_vars, name=name)
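So to run the whole algorithm you call minimize() in a loop, one step per call. A minimal sketch on an assumed toy problem (a single scalar variable, not the asker's network):

```python
import tensorflow as tf

# Illustrative toy problem: minimize (w - 3.0)^2 with plain SGD.
w = tf.Variable(0.0)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
loss = lambda: (w - 3.0) ** 2  # callable loss, so no tape is needed

opt.minimize(loss, var_list=[w])   # ONE gradient step: w moves from 0.0 to 0.6
after_one_step = w.numpy()

# The full gradient descent algorithm is the loop you write yourself:
for _ in range(99):
    opt.minimize(loss, var_list=[w])
```

After a single call the loss is still far from zero; only the loop drives w toward the minimum at 3.0. If your loss hits zero after one minimize() call, that is a property of your loss landscape (or a bug), not of minimize() running multiple steps.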