从ray.tune中提取代理



我一直在使用azure机器学习来训练使用ray.tune的强化学习代理。

我的训练函数如下:

tune.run(
run_or_experiment="PPO",
config={
"env": "Battery",
"num_gpus" : 1,
"num_workers": 13,
"num_cpus_per_worker": 1,
"train_batch_size": 1024,
"num_sgd_iter": 20,
'explore': True,
'exploration_config': {'type': 'StochasticSampling'},
},
stop={'episode_reward_mean': 0.15},
checkpoint_freq = 200,
local_dir = 'second_checkpoints'

)

我如何从检查点提取代理,以便我可以可视化我的健身房环境中的动作,如下所示:

while not done:
action, state, logits = agent.compute_action(obs, state)
obs, reward, done, info = env.step(action)
episode_reward += reward
print('action: ' + str(action) + 'reward: ' + str(reward))

我明白我可以这样写:

analysis = tune.run('PPO",config={"max_iter": 10}, restore=last_ckpt)

但是我不确定如何从存在于tune.run中的代理中提取计算动作(和奖励)。

tune run是用来训练模型的。培训后,您应该有一些检查点文件。这些文件可以加载,然后在你的环境中播放。

agent = ppo.PPOTrainer(config=config, env=env_name)
agent.restore(checkpoint_file)
obs = env.reset()
action = agent.compute_action(obs)
obs, reward, done, info = env.step(action)

最新更新