如何在稳定的基线3中获得DQN中的Q值



我有一个Box格式的观察空间,但实际上定义为numpy数组。

例如:

Box(low=np.array([0, 0, 0]), high=np.array([15, 10,150]))

现在我想得到单个观测的q_value,但由于观测是Box,所以稳定基线3的代码是:

if isinstance(observation_space, spaces.Box):
return obs.float()

但是,输入观察没有float属性,所以在这种情况下,我如何访问所有操作的q_values?

所以,我想好了如何解决它。如果这也是别人的问题,我会把它发布在这里。

observation = obs.reshape((-1,) + model.observation_space.shape)
observation = obs_as_tensor(observation, device)
with th.no_grad():
q_values = model.q_net(observation)

相关内容

  • 没有找到相关文章

最新更新