在Ray RLlib中使用图形模式会在PPOTFPolicy中调用tf.keras.model.predict()函数时



我正在使用Ray RLlib对PPOTFPolicy进行两次修改来训练PPO代理。

  • 我添加了一个mixin类(比如"Recal"(到";mixins";参数in";build_tf_policy(("。这样,PPOTFPolicy将把我的";重新计算";类中定义的成员函数的访问权限;重新计算";。我的"重新计算";类是tf.keras.Model的一个简单子类
  • 我定义了一个";my_postprocess_fn";用以取代";compute_gae_for_sample_batch"赋予参数"0"的函数;postprocess_fn";在";build_tf_policy((">

;PPOTrainer=build_trainer(…(";函数保持不变。我使用框架=";tf";,并使渴望模式被禁用。

Psuedo代码如下。这是colab的运行版本。

tf.compat.v1.disable_eager_execution()
class Recal:
def __init__(self):
self.recal_model = build_and_compile_keras_model()
def my_postprocess_fn(policy, sample_batch):
with policy.model.graph.as_default():
sample_batch = policy.recal_model.predict(sample_batch)
return compute_gae_for_sample_batch(policy, sample_batch)
PPOTFPolicy = build_tf_policy(..., postprocess_fn=my_postprocess_fn, mixins=[..., Recal])
PPOTrainer = build_trainer(...)
ppo_trainer = PPOTrainer(config=DEFAULT_CONFIG, env="CartPole-v0")
for i in range(1):
result = ppo_trainer.train()

这样";重新计算";类是PPOTFPolicy的基类,并且当PPOTFPolcy的实例被创建时;重新计算";在同一tensorflow图中实例化。但是,当调用my_postprocess_fn((时,它会引发一个错误(见下文(。

tensorflow.python.framework.errors_impl.FailedPreconditionError: Could not find variable default_policy_wk1/my_model/dense/kernel. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Container localhost does not exist. (Could not find resource: localhost/default_policy_wk1/my_model/dense/kernel)
[[{{node default_policy_wk1/my_model_1/dense/MatMul/ReadVariableOp}}]]

我已经和Ray一起探索了一段时间。所以我想我可以给你一个答案。

Ray使用自己版本的Model类。并且这个类没有tf.keras.Model.prdict方法来获取批量预测。然而,它确实提供了其他选择。

我还没有发现这两个类的输出是否相等。在寻找这个问题的答案的过程中,只有我遇到了你的问题。如果你看到这一点,我很乐意继续对话。:(

最新更新