我想在每个回合(episode)结束后获取数据。我在文档中读到可以使用 stable_baselines3.common.monitor.ResultsWriter,
但我不知道如何在代码中实现它。
import gym
import numpy as np
import Neural_Traffic_Env
import stable_baselines3
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, CallbackList, StopTrainingOnMaxEpisodes, EveryNTimesteps
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.monitor import Monitor, ResultsWriter
# Training environment. The Monitor wrapper records every finished episode
# (reward, length, elapsed time); because a filename is given, those rows are
# also written to a monitor CSV file derived from the name "Monitor".
env = gym.make('NeuralTraffic-v1')
env = Monitor(env, filename="Monitor")

# Evaluation gets its own environment instance. Re-using the training env
# would let EvalCallback reset/step it in the middle of training and would mix
# evaluation episodes into the training Monitor log. It is wrapped in Monitor
# (without a file) so EvalCallback can read true episode rewards and lengths.
# NOTE(review): assumes 'NeuralTraffic-v1' can be instantiated twice in one
# process -- confirm for this custom environment.
eval_env = Monitor(gym.make('NeuralTraffic-v1'))

# Periodically evaluate the agent, keeping the best model and the evaluation
# results under ./logs/.
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/best_model', log_path='./logs/', eval_freq=500)
# Save a checkpoint of the model every 100 calls of the callback.
checkpoint_callback = CheckpointCallback(save_freq=100, save_path='./saves/')
# Stop training once 1000 episodes have been played, even if total_timesteps
# has not been reached yet.
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=1000, verbose=1)
callback = CallbackList([callback_max_episodes, checkpoint_callback, eval_callback])

# Gaussian exploration noise with one component per action dimension.
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
# total_timesteps is an integer step budget; 1_000_000 replaces the float 1e6.
model.learn(total_timesteps=1_000_000, log_interval=1, callback=callback)
model.save("ddpg")

# After training, per-episode statistics are available from the Monitor CSV
# file written above. get_env() returns the (vectorized) environment the model
# was trained on.
env = model.get_env()
另外,有没有 Stable Baselines 的论坛,可以让我直接在那里提问?
from stable_baselines3.common.logger import configure
from stable_baselines3.common.monitor import Monitor
# PPO is used below but is not imported anywhere else in this file.
from stable_baselines3 import PPO

# Directory that receives the log files produced by the logger below.
tmp_path = "./tmp/sb3_log/"
# Set up a logger that reports training statistics to stdout, to a CSV file
# and to TensorBoard event files inside tmp_path.
new_logger = configure(tmp_path, ["stdout", "csv", "tensorboard"])

# `env` is expected to exist already (e.g. the Monitor-wrapped environment
# created by the training script).
model = PPO('MlpPolicy', env, verbose=1)
# Replace the model's default logger so subsequent training writes to the
# outputs configured above.
model.set_logger(new_logger)