Received incompatible tensor at flattened index 4 from table 'uniform_table'



I am trying to adapt the TensorFlow Agents tutorial to a custom environment. It is not very complex and is mainly meant to teach me how things work. The game is basically a 21x21 grid in which the agent earns rewards by walking around and collecting tokens. I can validate the environment, the agent and the replay buffer, but when I try to train the model I get an error message (see bottom). Any advice is welcome!
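For reference, by "validate the environment" I mean the standard TF-Agents check; a minimal sketch of that call (run against the cGame class posted below):

from tf_agents.environments import utils

# Runs a handful of episodes with random actions and checks every produced
# time_step against the declared action/observation specs.
utils.validate_py_environment(cGame(), episodes=5)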

The agent class is:

import numpy as np
import random
from IPython.display import clear_output
import time

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import tensorflow as tf
import numpy as np
from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts
class cGame (py_environment.PyEnvironment):
    def __init__(self):
        self.xdim = 21
        self.ydim = 21
        self.mmap = np.array([[0]*self.xdim]*self.ydim)
        self._turnNumber = 0
        self.playerPos = {"x":1, "y":1}
        self.totalScore = 0
        self.reward = 0.0
        self.input = 0
        self.addRewardEveryNTurns = 4
        self.addBombEveryNTurns = 3
        self._episode_ended = False

        ## player = 13
        ## bomb   = 14

        self._action_spec = array_spec.BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=3, name='action')
        self._observation_spec = array_spec.BoundedArraySpec(shape = (441,),  minimum=np.array([-1]*441), maximum = np.array([20]*441), dtype=np.int32, name='observation')  #(self.xdim, self.ydim)  , self.mmap.shape,  minimum = -1, maximum = 10

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._observation_spec

    def addMapReward(self):
        dx = random.randint(1, self.xdim-2)
        dy = random.randint(1, self.ydim-2)
        if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
            self.mmap[dy][dx] = random.randint(1, 9)
        return True

    def addBombToMap(self):
        dx = random.randint(1, self.xdim-2)
        dy = random.randint(1, self.ydim-2)
        if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
            self.mmap[dy][dx] = 14
        return True

    def _reset (self):
        self.mmap = np.array([[0]*self.xdim]*self.ydim)
        for y in range(self.ydim):
            self.mmap[y][0] = -1
            self.mmap[y][self.ydim-1] = -1
        for x in range(self.xdim):
            self.mmap[0][x] = -1
            self.mmap[self.ydim-1][x] = -1

        self.playerPos["x"] = random.randint(1, self.xdim-2)
        self.playerPos["y"] = random.randint(1, self.ydim-2)
        self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13

        for z in range(10):
            ## place 10 targets
            self.addMapReward()
        for z in range(5):
            ## place 5 bombs
            ## bomb   = 14
            self.addBombToMap()
        self._turnNumber = 0
        self._episode_ended = False
        #return ts.restart (self.mmap)
        dap = ts.restart(np.array(self.mmap, dtype=np.int32).flatten())
        return (dap)

    def render(self, mapToRender):
        mapToRender.reshape(21,21)
        for y in range(self.ydim):
            o =""
            for x in range(self.xdim):
                if mapToRender[y][x]==-1:
                    o=o+"#"
                elif mapToRender[y][x]>0 and mapToRender[y][x]<10:
                    o=o+str(mapToRender[y][x])
                elif mapToRender[y][x] == 13:
                    o=o+"@"
                elif mapToRender[y][x] == 14:
                    o=o+"*"
                else:
                    o=o+" "
            print (o)
        print ('TOTAL SCORE:', self.totalScore, 'LAST TURN SCORE:', self.reward)
        return True

    def getInput(self):
        self.input = 0
        i = input()
        if i == 'w' or i == '0':
            print ('going N')
            self.input = 1
        if i == 's' or i == '1':
            print ('going S')
            self.input = 2
        if i == 'a' or i == '2':
            print ('going W')
            self.input = 3
        if i == 'd' or i == '3':
            print ('going E')
            self.input = 4
        if i == 'x':
            self.input = 5
        return self.input

    def processMove(self):

        self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0
        self.reward = 0
        if self.input == 0:
            self.playerPos["y"] -=1
        if self.input == 1:
            self.playerPos["y"] +=1
        if self.input == 2:
            self.playerPos["x"] -=1
        if self.input == 3:
            self.playerPos["x"] +=1

        cloc = self.mmap[self.playerPos["y"]][self.playerPos["x"]]

        if  cloc == -1 or cloc ==14:
            self.totalScore = 0
            self.reward = -99

        if cloc >0 and cloc < 10:
            self.totalScore += cloc
            self.reward = cloc
            self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0
        self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13
        self.render(self.mmap)

    def runTurn(self):
        clear_output(wait=True)
        if self._turnNumber % self.addRewardEveryNTurns == 0:
            self.addMapReward()
        if self._turnNumber % self.addBombEveryNTurns == 0:
            self.addBombToMap()

        self.getInput()
        self.processMove()
        self._turnNumber +=1
        if self.reward == -99:
            self._turnNumber +=1
            self._reset()
            self.totalScore = 0
        self.render(self.mmap)
        return (self.reward)

    def _step (self, action):

        if self._episode_ended == True:
            return self._reset()

        clear_output(wait=True)
        if self._turnNumber % self.addRewardEveryNTurns == 0:
            self.addMapReward()
        if self._turnNumber % self.addBombEveryNTurns == 0:
            self.addBombToMap()
        ## make sure action does produce exceed range
        #if action > 5 or action <1:
        #    action =0
        self.input = action  ## value 1 to 4
        self.processMove()
        self._turnNumber +=1

        if self.reward == -99:
            self._turnNumber +=1
            self._episode_ended = True
            #self._reset()
            self.totalScore = 0
            self.render(self.mmap)
            return ts.termination(np.array(self.mmap, dtype=np.int32).flatten(), reward = self.reward)
        else:
            return ts.transition(np.array(self.mmap, dtype=np.int32).flatten(), reward = self.reward) #, discount = 1.0

    def run (self):
        self._reset()
        self.render(self.mmap)
        while (True):
            self.runTurn()
            if self.input == 5:
                return ("EXIT on input x ")

env = cGame()
The class I want to use to train the model is:

from tf_agents.specs import tensor_spec
from tf_agents.networks import sequential
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
import reverb
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.trajectories import trajectory
from tf_agents.drivers import py_driver
from tf_agents.environments import BatchedPyEnvironment

class mTrainer:
    def __init__ (self):

        self.train_env = tf_py_environment.TFPyEnvironment(cGame())
        self.eval_env  = tf_py_environment.TFPyEnvironment(cGame())

        self.num_iterations = 20000 # @param {type:"integer"}
        self.initial_collect_steps = 100  # @param {type:"integer"}
        self.collect_steps_per_iteration = 100 # @param {type:"integer"}
        self.replay_buffer_max_length = 100000  # @param {type:"integer"}
        self.batch_size = 64  # @param {type:"integer"}
        self.learning_rate = 1e-3  # @param {type:"number"}
        self.log_interval = 200  # @param {type:"integer"}
        self.num_eval_episodes = 10  # @param {type:"integer"}
        self.eval_interval = 1000  # @param {type:"integer"}

    def createAgent(self):
        fc_layer_params = (100, 50)
        action_tensor_spec = tensor_spec.from_spec(self.train_env.action_spec())
        num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

        def dense_layer(num_units):
            return tf.keras.layers.Dense(
                num_units,
                activation=tf.keras.activations.relu,
                kernel_initializer=tf.keras.initializers.VarianceScaling(
                    scale=2.0, mode='fan_in', distribution='truncated_normal'))

        dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
        q_values_layer = tf.keras.layers.Dense(
            num_actions,
            activation=None,
            kernel_initializer=tf.keras.initializers.RandomUniform(
                minval=-0.03, maxval=0.03),
            bias_initializer=tf.keras.initializers.Constant(-0.2))

        self.q_net = sequential.Sequential(dense_layers + [q_values_layer])

        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
        #rain_step_counter = tf.Variable(0)
        self.agent = dqn_agent.DqnAgent(
            time_step_spec = self.train_env.time_step_spec(),
            action_spec = self.train_env.action_spec(),
            q_network=self.q_net,
            optimizer=optimizer,
            td_errors_loss_fn=common.element_wise_squared_loss,
            train_step_counter=tf.Variable(0))
        self.agent.initialize()

        self.eval_policy = self.agent.policy
        self.collect_policy = self.agent.collect_policy
        self.random_policy = random_tf_policy.RandomTFPolicy(self.train_env.time_step_spec(),self.train_env.action_spec())
        return True

    def compute_avg_return(self, environment, policy, num_episodes=10):
        #mT.compute_avg_return(mT.eval_env, mT.random_policy, 50)
        total_return = 0.0
        for _ in range(num_episodes):
            time_step = environment.reset()
            episode_return = 0.0
            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = environment.step(action_step.action)
                episode_return += time_step.reward
            total_return += episode_return
        avg_return = total_return / num_episodes
        print ('average return :', avg_return.numpy()[0])
        return avg_return.numpy()[0]

    def create_replaybuffer(self):
        table_name = 'uniform_table'
        replay_buffer_signature = tensor_spec.from_spec(self.agent.collect_data_spec)
        replay_buffer_signature = tensor_spec.add_outer_dim(replay_buffer_signature)
        table = reverb.Table(table_name,
                             max_size=self.replay_buffer_max_length,
                             sampler=reverb.selectors.Uniform(),
                             remover=reverb.selectors.Fifo(),
                             rate_limiter=reverb.rate_limiters.MinSize(1),
                             signature=replay_buffer_signature)
        reverb_server = reverb.Server([table])
        self.replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
            self.agent.collect_data_spec,
            table_name=table_name,
            sequence_length=2,
            local_server=reverb_server)
        self.rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
            self.replay_buffer.py_client,
            table_name,
            sequence_length=2)

        self.dataset = self.replay_buffer.as_dataset(num_parallel_calls=3,sample_batch_size=self.batch_size,num_steps=2).prefetch(3)
        self.iterator = iter(self.dataset)

    def testReplayBuffer(self):
        py_driver.PyDriver(
            self.train_env,
            py_tf_eager_policy.PyTFEagerPolicy(
                self.random_policy,
                use_tf_function=True),
            [self.rb_observer],
            max_steps=self.initial_collect_steps).run(self.train_env.reset())

    def trainAgent(self):

        print (self.collect_policy)
        # Create a driver to collect experience.
        collect_driver = py_driver.PyDriver(
            self.train_env,
            py_tf_eager_policy.PyTFEagerPolicy(
                self.agent.collect_policy,
                batch_time_steps=False,
                use_tf_function=True),
            [self.rb_observer],
            max_steps=self.collect_steps_per_iteration)

        # Reset the environment.
        time_step = self.train_env.reset()

        for _ in range(self.num_iterations):
            # Collect a few steps and save to the replay buffer.
            time_step, _ = collect_driver.run(time_step)
            # Sample a batch of data from the buffer and update the agent's network.
            experience, unused_info = next(self.iterator)
            train_loss = agent.train(experience).loss
            step = agent.train_step_counter.numpy()
            if step % log_interval == 0:
                print('step = {0}: loss = {1}'.format(step, train_loss))
            if step % eval_interval == 0:
                avg_return = self.compute_avg_return(self.eval_env, agent.policy, num_eval_episodes)
                print('step = {0}: Average Return = {1}'.format(step, avg_return))
                self.returns.append(avg_return)

    def run(self):
        self.createAgent()
        #self.compute_avg_return(self.train_env,self.eval_policy)
        self.create_replaybuffer()
        #self.testReplayBuffer()
        self.trainAgent()
        return True

mT = mTrainer()
mT.run()

It produces this error message:

InvalidArgumentError: Received incompatible tensor at flattened index 4 from table 'uniform_table'. Specification has (dtype, shape): (int32, [?]). Tensor has (dtype, shape): (int32, [2,1]). Table signature: 0: Tensor<name: 'key', dtype: uint64, shape: []>, 1: Tensor<name: 'probability', dtype: double, shape: []>, 2: Tensor<name: 'table_size', dtype: int64, shape: []>, 3: Tensor<name: 'priority', dtype: double, shape: []>, 4: Tensor<name: 'step_type/step_type', dtype: int32, shape: [?]>, 5: Tensor<name: 'observation/observation', dtype: int32, shape: [?,441]>, 6: Tensor<name: 'action/action', dtype: int32, shape: [?]>, 7: Tensor<name: 'next_step_type/step_type', dtype: int32, shape: [?]>, 8: Tensor…, 9: Tensor… [Op:IteratorGetNext]

I ran into a similar problem. The cause is that you are using a TensorFlow environment as the argument to the PyDriver that collects the data. A TensorFlow environment adds a batch dimension to every tensor it produces, so each generated time_step carries an extra dimension of size 1.

Now, when you retrieve data from the replay buffer, every time_step has that extra dimension, which is incompatible with the data the agent's train function expects, hence the error.
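To see where the extra dimension comes from, here is a minimal sketch (assuming the cGame class from your question is already defined); the shapes in the comments are what I would expect, not captured output:

from tf_agents.environments import tf_py_environment

py_env = cGame()                                      # plain Python environment
tf_env = tf_py_environment.TFPyEnvironment(cGame())   # TF wrapper adds a batch dimension

print(py_env.reset().observation.shape)   # expected: (441,)
print(tf_env.reset().observation.shape)   # expected: (1, 441)

# With sequence_length=2, Reverb stacks two such steps, so a scalar field like
# step_type gets stored as [2, 1] instead of the [?] in the table signature,
# which is exactly the shape reported at flattened index 4.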

You need to use a Python environment here so that the collected data has the correct dimensions. With that change you also no longer need to pass batch_time_steps=False.
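Concretely, the collection part of trainAgent ends up looking roughly like this (just a condensed view of the edits marked CHANGE 1 and CHANGE 2 in the full code further down):

py_env = cGame()                          # CHANGE 1: collect with a Python environment
collect_driver = py_driver.PyDriver(
    py_env,
    py_tf_eager_policy.PyTFEagerPolicy(
        self.agent.collect_policy,        # CHANGE 2: batch_time_steps=False no longer needed
        use_tf_function=True),
    [self.rb_observer],
    max_steps=self.collect_steps_per_iteration)

time_step = py_env.reset()                # reset the Python environment, not train_env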

I don't know how to collect data of the correct dimensions with a TensorFlow environment, so I modified your code to collect data with a Python environment; it should run now.

PS - there are a few minor errors in the code you posted (e.g. using log_interval instead of self.log_interval, etc.).

Agent class:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
from IPython.display import clear_output
import time

import abc
import tensorflow as tf
import numpy as np
from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts

class cGame(py_environment.PyEnvironment):
    def __init__(self):
        self.xdim = 21
        self.ydim = 21
        self.mmap = np.array([[0] * self.xdim] * self.ydim)
        self._turnNumber = 0
        self.playerPos = {"x": 1, "y": 1}
        self.totalScore = 0
        self.reward = 0.0
        self.input = 0
        self.addRewardEveryNTurns = 4
        self.addBombEveryNTurns = 3
        self._episode_ended = False
        ## player = 13
        ## bomb   = 14
        self._action_spec = array_spec.BoundedArraySpec(shape=(),
                                                        dtype=np.int32,
                                                        minimum=0, maximum=3,
                                                        name='action')
        self._observation_spec = array_spec.BoundedArraySpec(shape=(441,),
                                                             minimum=np.array(
                                                                 [-1] * 441),
                                                             maximum=np.array(
                                                                 [20] * 441),
                                                             dtype=np.int32,
                                                             name='observation')  # (self.xdim, self.ydim)  , self.mmap.shape,  minimum = -1, maximum = 10

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._observation_spec

    def addMapReward(self):
        dx = random.randint(1, self.xdim - 2)
        dy = random.randint(1, self.ydim - 2)
        if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
            self.mmap[dy][dx] = random.randint(1, 9)
        return True

    def addBombToMap(self):
        dx = random.randint(1, self.xdim - 2)
        dy = random.randint(1, self.ydim - 2)
        if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
            self.mmap[dy][dx] = 14
        return True

    def _reset(self):
        self.mmap = np.array([[0] * self.xdim] * self.ydim)
        for y in range(self.ydim):
            self.mmap[y][0] = -1
            self.mmap[y][self.ydim - 1] = -1
        for x in range(self.xdim):
            self.mmap[0][x] = -1
            self.mmap[self.ydim - 1][x] = -1
        self.playerPos["x"] = random.randint(1, self.xdim - 2)
        self.playerPos["y"] = random.randint(1, self.ydim - 2)
        self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13
        for z in range(10):
            ## place 10 targets
            self.addMapReward()
        for z in range(5):
            ## place 5 bombs
            ## bomb   = 14
            self.addBombToMap()
        self._turnNumber = 0
        self._episode_ended = False
        # return ts.restart (self.mmap)
        dap = ts.restart(np.array(self.mmap, dtype=np.int32).flatten())
        return (dap)

    def render(self, mapToRender):
        mapToRender.reshape(21, 21)
        for y in range(self.ydim):
            o = ""
            for x in range(self.xdim):
                if mapToRender[y][x] == -1:
                    o = o + "#"
                elif mapToRender[y][x] > 0 and mapToRender[y][x] < 10:
                    o = o + str(mapToRender[y][x])
                elif mapToRender[y][x] == 13:
                    o = o + "@"
                elif mapToRender[y][x] == 14:
                    o = o + "*"
                else:
                    o = o + " "
            print(o)
        print('TOTAL SCORE:', self.totalScore, 'LAST TURN SCORE:', self.reward)
        return True

    def getInput(self):
        self.input = 0
        i = input()
        if i == 'w' or i == '0':
            print('going N')
            self.input = 1
        if i == 's' or i == '1':
            print('going S')
            self.input = 2
        if i == 'a' or i == '2':
            print('going W')
            self.input = 3
        if i == 'd' or i == '3':
            print('going E')
            self.input = 4
        if i == 'x':
            self.input = 5
        return self.input

    def processMove(self):
        self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0
        self.reward = 0
        if self.input == 0:
            self.playerPos["y"] -= 1
        if self.input == 1:
            self.playerPos["y"] += 1
        if self.input == 2:
            self.playerPos["x"] -= 1
        if self.input == 3:
            self.playerPos["x"] += 1
        cloc = self.mmap[self.playerPos["y"]][self.playerPos["x"]]
        if cloc == -1 or cloc == 14:
            self.totalScore = 0
            self.reward = -99
        if cloc > 0 and cloc < 10:
            self.totalScore += cloc
            self.reward = cloc
            self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0
        self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13
        self.render(self.mmap)

    def runTurn(self):
        clear_output(wait=True)
        if self._turnNumber % self.addRewardEveryNTurns == 0:
            self.addMapReward()
        if self._turnNumber % self.addBombEveryNTurns == 0:
            self.addBombToMap()
        self.getInput()
        self.processMove()
        self._turnNumber += 1
        if self.reward == -99:
            self._turnNumber += 1
            self._reset()
            self.totalScore = 0
        self.render(self.mmap)
        return (self.reward)

    def _step(self, action):
        if self._episode_ended == True:
            return self._reset()
        clear_output(wait=True)
        if self._turnNumber % self.addRewardEveryNTurns == 0:
            self.addMapReward()
        if self._turnNumber % self.addBombEveryNTurns == 0:
            self.addBombToMap()
        ## make sure action does produce exceed range
        # if action > 5 or action <1:
        #    action =0
        self.input = action  ## value 1 to 4
        self.processMove()
        self._turnNumber += 1
        if self.reward == -99:
            self._turnNumber += 1
            self._episode_ended = True
            # self._reset()
            self.totalScore = 0
            self.render(self.mmap)
            return ts.termination(np.array(self.mmap, dtype=np.int32).flatten(),
                                  reward=self.reward)
        else:
            return ts.transition(np.array(self.mmap, dtype=np.int32).flatten(),
                                 reward=self.reward)  # , discount = 1.0

    def run(self):
        self._reset()
        self.render(self.mmap)
        while (True):
            self.runTurn()
            if self.input == 5:
                return ("EXIT on input x ")


env = cGame()

Driver code:

from tf_agents.specs import tensor_spec
from tf_agents.networks import sequential
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
import reverb
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.trajectories import trajectory
from tf_agents.drivers import py_driver
from tf_agents.environments import BatchedPyEnvironment


class mTrainer:
    def __init__(self):

        self.returns = None
        self.train_env = tf_py_environment.TFPyEnvironment(cGame())
        self.eval_env = tf_py_environment.TFPyEnvironment(cGame())

        self.num_iterations = 20000  # @param {type:"integer"}
        self.initial_collect_steps = 100  # @param {type:"integer"}
        self.collect_steps_per_iteration = 100  # @param {type:"integer"}
        self.replay_buffer_max_length = 100000  # @param {type:"integer"}
        self.batch_size = 64  # @param {type:"integer"}
        self.learning_rate = 1e-3  # @param {type:"number"}
        self.log_interval = 200  # @param {type:"integer"}
        self.num_eval_episodes = 10  # @param {type:"integer"}
        self.eval_interval = 1000  # @param {type:"integer"}

    def createAgent(self):
        fc_layer_params = (100, 50)
        action_tensor_spec = tensor_spec.from_spec(self.train_env.action_spec())
        num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

        def dense_layer(num_units):
            return tf.keras.layers.Dense(
                num_units,
                activation=tf.keras.activations.relu,
                kernel_initializer=tf.keras.initializers.VarianceScaling(
                    scale=2.0, mode='fan_in', distribution='truncated_normal'))

        dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
        q_values_layer = tf.keras.layers.Dense(
            num_actions,
            activation=None,
            kernel_initializer=tf.keras.initializers.RandomUniform(
                minval=-0.03, maxval=0.03),
            bias_initializer=tf.keras.initializers.Constant(-0.2))

        self.q_net = sequential.Sequential(dense_layers + [q_values_layer])

        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
        # rain_step_counter = tf.Variable(0)

        self.agent = dqn_agent.DqnAgent(
            time_step_spec=self.train_env.time_step_spec(),
            action_spec=self.train_env.action_spec(),
            q_network=self.q_net,
            optimizer=optimizer,
            td_errors_loss_fn=common.element_wise_squared_loss,
            train_step_counter=tf.Variable(0))

        self.agent.initialize()

        self.eval_policy = self.agent.policy
        self.collect_policy = self.agent.collect_policy
        self.random_policy = random_tf_policy.RandomTFPolicy(
            self.train_env.time_step_spec(), self.train_env.action_spec())
        return True

    def compute_avg_return(self, environment, policy, num_episodes=10):
        # mT.compute_avg_return(mT.eval_env, mT.random_policy, 50)
        total_return = 0.0
        for _ in range(num_episodes):
            time_step = environment.reset()
            episode_return = 0.0
            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = environment.step(action_step.action)
                episode_return += time_step.reward
            total_return += episode_return
        avg_return = total_return / num_episodes
        print('average return :', avg_return.numpy()[0])
        return avg_return.numpy()[0]

    def create_replaybuffer(self):

        table_name = 'uniform_table'
        replay_buffer_signature = tensor_spec.from_spec(
            self.agent.collect_data_spec)
        replay_buffer_signature = tensor_spec.add_outer_dim(
            replay_buffer_signature)

        table = reverb.Table(table_name,
                             max_size=self.replay_buffer_max_length,
                             sampler=reverb.selectors.Uniform(),
                             remover=reverb.selectors.Fifo(),
                             rate_limiter=reverb.rate_limiters.MinSize(1),
                             signature=replay_buffer_signature)

        reverb_server = reverb.Server([table])

        self.replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
            self.agent.collect_data_spec,
            table_name=table_name,
            sequence_length=2,
            local_server=reverb_server)

        self.rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
            self.replay_buffer.py_client,
            table_name,
            sequence_length=2)

        self.dataset = self.replay_buffer.as_dataset(num_parallel_calls=3,
                                                     sample_batch_size=self.batch_size,
                                                     num_steps=2).prefetch(3)
        self.iterator = iter(self.dataset)

    def testReplayBuffer(self):
        py_env = cGame()
        py_driver.PyDriver(
            py_env,
            py_tf_eager_policy.PyTFEagerPolicy(
                self.random_policy,
                use_tf_function=True),
            [self.rb_observer],
            max_steps=self.initial_collect_steps).run(self.train_env.reset())

    def trainAgent(self):

        self.returns = list()
        print(self.collect_policy)
        py_env = cGame()
        # Create a driver to collect experience.
        collect_driver = py_driver.PyDriver(
            py_env,  # CHANGE 1
            py_tf_eager_policy.PyTFEagerPolicy(
                self.agent.collect_policy,
                # batch_time_steps=False, # CHANGE 2
                use_tf_function=True),
            [self.rb_observer],
            max_steps=self.collect_steps_per_iteration)

        # Reset the environment.
        # time_step = self.train_env.reset()
        time_step = py_env.reset()
        for _ in range(self.num_iterations):

            # Collect a few steps and save to the replay buffer.
            time_step, _ = collect_driver.run(time_step)

            # Sample a batch of data from the buffer and update the agent's network.
            experience, unused_info = next(self.iterator)
            train_loss = self.agent.train(experience).loss

            step = self.agent.train_step_counter.numpy()

            if step % self.log_interval == 0:
                print('step = {0}: loss = {1}'.format(step, train_loss))

            if step % self.eval_interval == 0:
                avg_return = self.compute_avg_return(self.eval_env,
                                                     self.agent.policy,
                                                     self.num_eval_episodes)
                print(
                    'step = {0}: Average Return = {1}'.format(step, avg_return))
                self.returns.append(avg_return)

    def run(self):
        self.createAgent()
        # self.compute_avg_return(self.train_env,self.eval_policy)
        self.create_replaybuffer()
        # self.testReplayBuffer()
        self.trainAgent()
        return True


if __name__ == '__main__':
    mT = mTrainer()
    mT.run()
