Problem with PettingZoo and Stable-Baselines3 with a ParallelEnv



I am having trouble with a custom ParallelEnv written with PettingZoo. I use SuperSuit's ss.pettingzoo_env_to_vec_env_v1(env) wrapper to vectorize the environment and make it work with Stable-Baselines3, as documented here.

You can find a summary of the most relevant parts of the code below:

from typing import Optional
from gym import spaces
import random
import numpy as np
from pettingzoo import ParallelEnv
from pettingzoo.utils.conversions import parallel_wrapper_fn
import supersuit as ss
from gym.utils import EzPickle, seeding


def env(**kwargs):
    # Factory: build the parallel env and vectorize it for Stable-Baselines3
    env_ = parallel_env(**kwargs)
    env_ = ss.pettingzoo_env_to_vec_env_v1(env_)
    # env_ = ss.concat_vec_envs_v1(env_, 1)
    return env_


petting_zoo = env


class parallel_env(ParallelEnv, EzPickle):
    metadata = {'render_modes': ['ansi'], "name": "PlayerEnv-Multi-v0"}

    def __init__(self, n_agents: int = 20, new_step_api: bool = True) -> None:
        EzPickle.__init__(
            self,
            n_agents,
            new_step_api
        )
        self._episode_ended = False
        self.n_agents = n_agents
        self.possible_agents = [
            f"player_{idx}" for idx in range(n_agents)]
        self.agents = self.possible_agents[:]
        self.agent_name_mapping = dict(
            zip(self.possible_agents, list(range(len(self.possible_agents))))
        )
        # Each agent observes a vector with one float in [0, 1] per agent
        self.observation_spaces = spaces.Dict(
            {agent: spaces.Box(shape=(len(self.agents),),
                               dtype=np.float64, low=0.0, high=1.0)
             for agent in self.possible_agents}
        )
        self.action_spaces = spaces.Dict(
            {agent: spaces.Discrete(4) for agent in self.possible_agents}
        )
        self.current_step = 0

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)

    def observation_space(self, agent):
        return self.observation_spaces[agent]

    def action_space(self, agent):
        return self.action_spaces[agent]

    def __calculate_observation(self, agent_id: int) -> np.ndarray:
        # Dummy observation: just sample from the agent's observation space
        return self.observation_space(agent_id).sample()

    def __calculate_observations(self) -> np.ndarray:
        observations = {
            agent: self.__calculate_observation(
                agent_id=agent)
            for agent in self.agents
        }
        return observations

    def observe(self, agent):
        return self.__calculate_observation(agent_id=agent)

    def step(self, actions):
        if self._episode_ended:
            return self.reset()
        observations = self.__calculate_observations()
        # Dummy rewards: one random value per agent
        rewards = random.sample(range(100), self.n_agents)
        self.current_step += 1
        self._episode_ended = self.current_step >= 100
        infos = {agent: {} for agent in self.agents}
        dones = {agent: self._episode_ended for agent in self.agents}
        rewards = {
            self.agents[i]: rewards[i]
            for i in range(len(self.agents))
        }
        if self._episode_ended:
            self.agents = {}  # To satisfy `set(par_env.agents) == live_agents`
        return observations, rewards, dones, infos

    def reset(self,
              seed: Optional[int] = None,
              return_info: bool = False,
              options: Optional[dict] = None,):
        self.agents = self.possible_agents[:]
        self._episode_ended = False
        self.current_step = 0
        observations = self.__calculate_observations()
        return observations

    def render(self, mode="human"):
        # TODO: IMPLEMENT
        print("TO BE IMPLEMENTED")

    def close(self):
        pass

Unfortunately, when I try to test it with the following main program:

from stable_baselines3 import DQN, PPO
from stable_baselines3.common.env_checker import check_env
from dummy_env import dummy
from pettingzoo.test import parallel_api_test

if __name__ == '__main__':
    # Testing the parallel algorithm alone
    env_parallel = dummy.parallel_env()
    parallel_api_test(env_parallel)  # This works!

    # Testing the environment with the wrapper
    env = dummy.petting_zoo()
    # ERROR: AssertionError: The observation returned by the `reset()` method does not match the given observation space
    check_env(env)

    # Model initialization
    model = PPO("MlpPolicy", env, verbose=1)
    # ERROR: ValueError: could not broadcast input array from shape (20,20) into shape (20,)
    model.learn(total_timesteps=10_000)

I get the following error:

AssertionError: The observation returned by the `reset()` method does not match the given observation space

If I skip check_env(), I get this instead:

ValueError: could not broadcast input array from shape (20,20) into shape (20,)

It seems that ss.pettingzoo_env_to_vec_env_v1(env) is able to split the parallel environment into multiple vectorized environments, but this does not carry over to the reset() function.
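One way to see the mismatch directly is to compare what the wrapped environment's reset() returns against the observation space it advertises. This diagnostic sketch is my own addition, not part of the original code; it assumes the env() factory defined above and that the SuperSuit vector env exposes the usual observation_space attribute:

import numpy as np
from dummy_env import dummy

vec_env = dummy.petting_zoo()
obs = vec_env.reset()
# With 20 agents and 20-dimensional per-agent observations, this is expected
# to print a stacked (20, 20) array against a per-agent Box of shape (20,),
# which matches the broadcast error above.
print(np.asarray(obs).shape)
print(vec_env.observation_space)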

Does anyone know how to solve this?

Please refer to the GitHub repository to reproduce the problem.

Thanks to a discussion I had in the issues section of the SuperSuit repository, I am able to post the solution to the problem. Thanks to jjshots!

First, it is necessary to have the latest SuperSuit version. To do that, I had to install Stable-Baselines3 following the instructions here, so that it works with gym 0.24+.
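A quick way to confirm which versions actually ended up installed (a trivial sketch using only the standard library):

from importlib.metadata import version

for pkg in ("supersuit", "stable-baselines3", "gym", "pettingzoo"):
    print(pkg, version(pkg))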

After that, taking the code in the question as a reference, it is necessary to replace

def env(**kwargs):
    env_ = parallel_env(**kwargs)
    env_ = ss.pettingzoo_env_to_vec_env_v1(env_)
    # env_ = ss.concat_vec_envs_v1(env_, 1)
    return env_

with

def env(**kwargs):
    env_ = parallel_env(**kwargs)
    env_ = ss.pettingzoo_env_to_vec_env_v1(env_)
    env_ = ss.concat_vec_envs_v1(env_, 1, base_class="stable_baselines3")
    return env_
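For completeness, here is how the fixed factory slots back into the main program from the question (a minimal sketch reusing the dummy module layout from above):

from stable_baselines3 import PPO
from dummy_env import dummy

if __name__ == '__main__':
    env = dummy.petting_zoo()  # now concatenated with base_class="stable_baselines3"
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10_000)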

The results are:

  • Result 1: leaving the line with check_env(env) in place, I get the error AssertionError: Your environment must inherit from the gym.Env class cf https://github.com/openai/gym/blob/master/gym/core.py
  • Result 2: removing the line with check_env(env), the agents successfully start training!

In the end, I believe it was the argument base_class="stable_baselines3" that did the trick. Only the minor issue with check_env remains to be reported, but I think it can be considered negligible as long as training works.
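As a side note, the same call is also where you would scale up data collection later: SuperSuit's concat_vec_envs_v1(vec_env, num_vec_envs, num_cpus=0, base_class='gym') takes the number of environment copies and worker processes as arguments. The counts below are purely illustrative:

# e.g. 4 copies of the vectorized environment across 2 worker processes
env_ = ss.concat_vec_envs_v1(env_, 4, num_cpus=2, base_class="stable_baselines3")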

You should double-check the reset() function in PettingZoo: it returns None, rather than an observation the way Gym's reset() does.
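To illustrate that comment (a sketch only; gym_env and aec_env are placeholders for a Gym environment and a PettingZoo AEC environment, and the exact behavior depends on the gym/PettingZoo versions in use):

obs = gym_env.reset()                       # Gym: reset() returns the initial observation
ret = aec_env.reset()                       # older PettingZoo AEC API: reset() returns None
obs0 = aec_env.observe(aec_env.agents[0])   # observations are fetched separately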
