我有一个Box格式的观察空间,但实际上定义为numpy数组。
例如:
Box(low=np.array([0, 0, 0]), high=np.array([15, 10,150]))
现在我想得到单个观测的q_value,但由于观测是Box,所以稳定基线3的代码是:
if isinstance(observation_space, spaces.Box):
return obs.float()
但是,输入观察没有float属性,所以在这种情况下,我如何访问所有操作的q_values?
所以,我想好了如何解决它。如果这也是别人的问题,我会把它发布在这里。
observation = obs.reshape((-1,) + model.observation_space.shape)
observation = obs_as_tensor(observation, device)
with th.no_grad():
q_values = model.q_net(observation)