稳定基线3运行时错误:mat1和mat2必须具有相同的数据类型



我正试图在Stable Baselines3中使用自定义环境来实现SAC,但标题中不断出现错误。该错误发生在任何偏离策略的算法中,而不仅仅是SAC。

追溯:

File "<MY PROJECT PATH>srcmain.py", line 70, in <module>
main()
File "<MY PROJECT PATH>srcmain.py", line 66, in main
model.learn(total_timesteps=timesteps, reset_num_timesteps=False, tb_log_name=f"sac_{num_cars}_cars")
File "<MY PROJECT PATH>venvlibsite-packagesstable_baselines3sacsac.py", line 309, in learn
return super().learn(
File "<MY PROJECT PATH>venvlibsite-packagesstable_baselines3commonoff_policy_algorithm.py", line 375, in learn
self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
File "<MY PROJECT PATH>venvlibsite-packagesstable_baselines3sacsac.py", line 256, in train
current_q_values = self.critic(replay_data.observations, replay_data.actions)
File "<MY PROJECT PATH>venvlibsite-packagestorchnnmodulesmodule.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "<MY PROJECT PATH>venvlibsite-packagesstable_baselines3commonpolicies.py", line 885, in forward
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
File "<MY PROJECT PATH>venvlibsite-packagesstable_baselines3commonpolicies.py", line 885, in <genexpr>
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
File "<MY PROJECT PATH>venvlibsite-packagestorchnnmodulesmodule.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "<MY PROJECT PATH>venvlibsite-packagestorchnnmodulescontainer.py", line 204, in forward
input = module(input)
File "<MY PROJECT PATH>venvlibsite-packagestorchnnmodulesmodule.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "<MY PROJECT PATH>venvlibsite-packagestorchnnmoduleslinear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype

行动和观察空间:

self.action_space = Box(low=-1., high=1., shape=(2,), dtype=np.float)
self.observation_space = Box(
np.array(
[-np.inf] * (9 * 40) + [-np.inf] * 3 + [-np.inf] * 3 + [-np.inf] * 3
+ [0.] + [0.] + [0.] + [-1.] + [0.] * 4 + [0.] * 4 + [0.] * 4,
dtype=np.float
),
np.array(
[np.inf] * (9 * 40) + [np.inf] * 3 + [np.inf] * 3 + [np.inf] * 3
+ [np.inf] + [1.] + [1.] + [1.] + [1.] * 4 + [np.inf] * 4 + [np.inf] * 4,
dtype=np.float
),
dtype=np.float
)

观测结果在步骤和重置方法中返回为浮点数的numpy数组。

是不是我遗漏了什么东西导致了这个错误?如果我使用健身房附带的环境,比如钟摆,效果很好,这就是为什么我认为我的定制环境有问题。

提前感谢您的帮助,如果需要更多信息,请告诉我。

将输入更改为float32,默认情况下加载程序将类型设置为float64。

inputs = inputs.to(torch.float32)

最新更新