TensorFlow 模型拟合与 train_on_batch 之间的差异

Difference between TensorFlow model fit and train_on_batch

我正在构建一个普通的 DQN 模型来玩 OpenAI gym Cartpole 游戏。

然而,在我输入状态作为输入并以目标 Q 值作为标签的训练步骤中,如果我使用 model.fit(x=states, y=target_q),它工作正常并且代理最终可以很好地玩游戏,但是如果我使用 model.train_on_batch(x=states, y=target_q),损失不会减少,模型也不会比随机策略更好地玩游戏。

请问fittrain_on_batch有什么区别?据我了解,fit 调用 train_on_batch 并在引擎盖下批量大小为 32,这应该没有区别,因为将批量大小指定为等于我输入的实际数据大小没有区别。

如果需要更多上下文信息来回答这个问题,完整的代码在这里:https://github.com/ultronify/cartpole-tf

model.fit 将训练 1 个或多个 epoch。这意味着它将训练多个批次。 model.train_on_batch,顾名思义,只训练一批

举一个具体的例子,假设你正在用 10 张图片训练一个模型。假设您的批量大小为 2。model.fit 将对所有 10 张图像进行训练,因此它将更新梯度 5 次。 (您可以指定多个时期,因此它会遍历您的数据集。)model.train_on_batch 将执行一次梯度更新,因为您只批量提供模型。如果您的批量大小为 2,您将提供 model.train_on_batch 两张图片。

如果我们假设 model.fit 在幕后调用 model.train_on_batch(虽然我不认为它会调用),那么 model.train_on_batch 将被调用多次,很可能在一个环形。这里用伪代码来解释。

def fit(x, y, batch_size, epochs=1):
    for epoch in range(epochs):
        for batch_x, batch_y in batch(x, y, batch_size):
            model.train_on_batch(batch_x, batch_y)