tf-agent, QNetwork => DqnAgent w/ tfa.optimizers.CyclicalLearningRate



是否有一种简单的本地方法可以在DqnAgent上实现tfa.optimizers.CyclicalLearningRate w/QNetwork?

尽量避免写我自己的DqnAgent。

我想更好的问题可能是,在DqnAgent上实现回调的正确方法是什么?

从您链接的教程中,他们设置优化器的部分是

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)
agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=train_step_counter)
agent.initialize()

因此,您可以用您希望使用的任何优化器来替换优化器。基于类似的文档

optimizer = tf.keras.optimizers.Adam(learning_rate=tfa.optimizers.CyclicalLearningRate)

应该可以工作,排除任何潜在的兼容性问题,因为他们在教程中使用了tf1.0adam。

相关内容

  • 没有找到相关文章

最新更新