TensorFlow模型拟合和train_on_batch之间的差异



我正在构建一个普通的DQN模型来玩OpenAI健身房Cartpole游戏。

但是,在训练步骤中,我将状态作为输入,目标 Q 值作为标签,如果我使用model.fit(x=states, y=target_q),它工作正常,代理最终可以很好地玩游戏,但如果我使用model.train_on_batch(x=states, y=target_q),损失不会减少,模型不会比随机策略更好地玩游戏。

我想知道fittrain_on_batch有什么区别?据我了解,fit调用train_on_batch的批大小为 32 应该没有区别,因为指定批大小等于我输入的实际数据大小没有区别。

如果需要更多上下文信息来回答这个问题,请在此处查看完整代码:https://github.com/ultronify/cartpole-tf

model.fit将训练 1 个或多个 epoch。这意味着它将训练多个批次。 顾名思义,model.train_on_batch只训练一批。

举一个具体的例子,假设你正在 10 张图像上训练一个模型。假设您的批量大小为 2。model.fit将对所有 10 张图像进行训练,因此它将更新 5 次渐变。(您可以指定多个纪元,以便它循环访问您的数据集。model.train_on_batch将执行一次梯度更新,因为您只批量提供模型。如果您的批大小为 2,您将model.train_on_batch提供两张图像。

如果我们假设model.fit在引擎盖下调用model.train_on_batch(尽管我认为并非如此(,那么model.train_on_batch将被多次调用,很可能在一个循环中。这是要解释的伪代码。

def fit(x, y, batch_size, epochs=1):
for epoch in range(epochs):
for batch_x, batch_y in batch(x, y, batch_size):
model.train_on_batch(batch_x, batch_y)

最新更新