我如何去这个错误在开放的AI健身房和stable_baselines3?



我正在学习Python中的强化学习与稳定基线3参考教程由sentdex。当我使用check_env()运行检查代码时,我得到一个错误AssertionError: The observation returned by thereset()method does not match the given observation space。显然,我不知道reset方法中的return有什么问题。

代码如下:

import gym
from gym import spaces
import numpy as np
import cv2
import random
import time
from collections import deque
SNAKE_LEN_GOAL = 30

def collision_with_apple(apple_position, score):
apple_position = [random.randrange(1, 50) * 10, random.randrange(1, 50) * 10]
score += 1
return apple_position, score

def collision_with_boundaries(snake_head):
if snake_head[0] >= 500 or snake_head[0] < 0 or snake_head[1] >= 500 or snake_head[1] < 0:
return 1
else:
return 0

def collision_with_self(snake_position):
snake_head = snake_position[0]
if snake_head in snake_position[1:]:
return 1
else:
return 0

class SnekEnv(gym.Env):
def __init__(self):
super(SnekEnv, self).__init__()
# Define action and observation space
# They must be gym.spaces objects
# Example when using discrete actions:
self.prev_actions = deque(maxlen=SNAKE_LEN_GOAL)  # however long we aspire the snake to be
self.action_space = spaces.Discrete(4)
# Example for using image as input (channel-first; channel-last also works):
self.observation_space = spaces.Box(low=-500, high=500,
shape=(5 + SNAKE_LEN_GOAL,), dtype=np.float32)
def step(self, action):
self.prev_actions.append(action)
cv2.imshow('a', self.img)
cv2.waitKey(1)
self.img = np.zeros((500, 500, 3), dtype='uint8')
# Display Apple
cv2.rectangle(self.img, (self.apple_position[0], self.apple_position[1]),
(self.apple_position[0] + 10, self.apple_position[1] + 10), (0, 0, 255), 3)
# Display Snake
for position in self.snake_position:
cv2.rectangle(self.img, (position[0], position[1]), (position[0] + 10, position[1] + 10), (0, 255, 0), 3)
# Takes step after fixed time
t_end = time.time() + 0.05
k = -1
while time.time() < t_end:
if k == -1:
k = cv2.waitKey(1)
else:
continue
button_direction = action
# Change the head position based on the button direction
if button_direction == 1:
self.snake_head[0] += 10
elif button_direction == 0:
self.snake_head[0] -= 10
elif button_direction == 2:
self.snake_head[1] += 10
elif button_direction == 3:
self.snake_head[1] -= 10
# Increase Snake length on eating apple
if self.snake_head == self.apple_position:
self.apple_position, self.score = collision_with_apple(self.apple_position, self.score)
self.snake_position.insert(0, list(self.snake_head))
else:
self.snake_position.insert(0, list(self.snake_head))
self.snake_position.pop()
# On collision kill the snake and print the score
if collision_with_boundaries(self.snake_head) == 1 or collision_with_self(self.snake_position) == 1:
font = cv2.FONT_HERSHEY_SIMPLEX
self.img = np.zeros((500, 500, 3), dtype='uint8')
cv2.putText(self.img, 'Your Score is {}'.format(self.score), (140, 250), font, 1, (255, 255, 255), 2,
cv2.LINE_AA)
cv2.imshow('a', self.img)
self.done = True
#self.total_reward = len(self.snake_position) - 3  # default length is 3
#self.reward = self.total_reward - self.prev_reward
#self.prev_reward = self.total_reward
if self.done:
self.reward = -10
else:
self.reward = self.score
head_x = self.snake_head[0]
head_y = self.snake_head[1]
apple_delta_x = self.apple_position[0] - head_x
apple_delta_y = self.apple_position[1] - head_y
snake_length = len(self.snake_position)
self.prev_actions = deque(maxlen=SNAKE_LEN_GOAL)
for _ in range(SNAKE_LEN_GOAL):
self.prev_actions(-1)
# create observation:
observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
observation = np.array(observation)
info = {}
return observation, self.reward, self.done, info
def reset(self):
self.img = np.zeros((500, 500, 3), dtype='uint8')
# Initial Snake and Apple position
self.snake_position = [[250, 250], [240, 250], [230, 250]]
self.apple_position = [random.randrange(1, 50) * 10, random.randrange(1, 50) * 10]
self.score = 0
self.prev_button_direction = 1
self.button_direction = 1
self.snake_head = [250, 250]
self.prev_reward = 0
self.done = False
head_x = self.snake_head[0]
head_y = self.snake_head[1]
apple_delta_x = self.apple_position[0] - head_x
apple_delta_y = self.apple_position[1] - head_y
snake_length = len(self.snake_position)
for i in range(SNAKE_LEN_GOAL):
self.prev_actions.append(-1)  # to create history
# create observation:
observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
observation = np.array(observation)
return observation

检查环境。

from stable_baselines3.common.env_checker import check_env
from snake_python_game_Env import SnekEnv

env = SnekEnv()
# It will check your custom environment and output additional warnings if needed
check_env(env)

错误。

Traceback (most recent call last):
File "C:UsersThis PCPycharmProjectspythonProjectsnake_python_game_agent.py", line 7, in <module>
check_env(env)
File "C:UsersThis PCAppDataLocalProgramsPythonPython38libsite-packagesstable_baselines3commonenv_checker.py", line 302, in check_env
_check_returned_values(env, observation_space, action_space)
File "C:UsersThis PCAppDataLocalProgramsPythonPython38libsite-packagesstable_baselines3commonenv_checker.py", line 159, in _check_returned_values
_check_obs(obs, observation_space, "reset")
File "C:UsersThis PCAppDataLocalProgramsPythonPython38libsite-packagesstable_baselines3commonenv_checker.py", line 112, in _check_obs
assert observation_space.contains(
AssertionError: The observation returned by the `reset()` method does not match the given observation space

根据教程的代码应该可以运行,但在我这边却不能。

我认为你应该改变你定义观察空间的行:

self.observation_space = spaces.Box(low=-500, high=500,
shape=(5 + SNAKE_LEN_GOAL,), dtype=int)

在这里,我改变了观察空间将具有的数据类型,在您的情况下,它似乎是整数值数组。当我在本地尝试时,它给出了一个不同的错误:TypeError: 'collections.deque' object is not callable是从环境的阶跃函数抛出的。希望对你有帮助。

相关内容

  • 没有找到相关文章

最新更新