I've been trying to build an RL agent with tf-agents in TensorFlow. I ran into this problem in a custom environment, but I reproduced it with an official TF Colab example. It occurs whenever I try to use a QRnnNetwork as the network for a DqnAgent: the agent works fine with a regular QNetwork, but with the QRnnNetwork the policy_state_spec changes shape. How can I fix this?
This is the shape that policy_state_spec turns into, while the original shape was ():
ListWrapper([TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_0'), TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_1')])
import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_rnn_network
from tf_agents.utils import common

# train_env and learning_rate are set as in the Colab tutorial; reproduced
# here so the snippet is self-contained.
train_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
learning_rate = 1e-3

q_net = q_rnn_network.QRnnNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    lstm_size=(16,),
)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()
collect_policy = agent.collect_policy

example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()

collect_policy.action(time_step)  # this call raises the error below
I get this error:
TypeError: policy_state and policy_state_spec structures do not match:
()
vs.
ListWrapper([., .])
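For reference, the mismatch is easy to see on the policy itself: TFPolicy.action defaults policy_state to (), which no longer matches the QRnnNetwork policy's non-empty spec. A quick check using the objects defined above:

print(collect_policy.policy_state_spec)
# ListWrapper([TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_0'),
#              TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_1')])

# collect_policy.action(time_step) is shorthand for
# collect_policy.action(time_step, policy_state=()),
# hence the "() vs. ListWrapper" in the TypeError.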
I dug into the code, and it seems that for RNNs the action(time_step, policy_state, seed) method requires you to supply the policy's state from the previous step, as the documentation describes:

policy_state: A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state.
https://www.tensorflow.org/agents/api_docs/python/tf_agents/policies/GreedyPolicy#action
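For the lstm_size=(16,) network above, that previous state is a pair of tensors of shape (batch_size, 16) (the LSTM's hidden and cell state). You don't have to build it by hand: get_initial_state returns a zeroed nest that matches the spec. A minimal sketch (my own illustration, not from the docs):

policy_state = collect_policy.get_initial_state(example_environment.batch_size)
# A nest of two float32 zero tensors of shape (1, 16), matching policy_state_spec.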
The error:
TypeError: policy_state and policy_state_spec structures do not match:
()
vs.
ListWrapper([., .])
is trying to say that you should pass the RNN's internal state to the action method. I found an example of this in the documentation:
https://www.tensorflow.org/agents/api_docs/python/tf_agents/policies/TFPolicy#example_usage
The code shown there (as of 2021-08-08) is:
env = SomeTFEnvironment()
policy = TFRandomPolicy(env.time_step_spec(), env.action_spec())
# Or policy = agent.policy or agent.collect_policy

policy_state = policy.get_initial_state(env.batch_size)
time_step = env.reset()

while not time_step.is_last():
  policy_step = policy.action(time_step, policy_state)
  time_step = env.step(policy_step.action)
  policy_state = policy_step.state
  # policy_step.info may contain side info for logging, such as action log
  # probabilities.
If you structure your code this way, it should work!
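Applied to the snippet from the question, that means asking the collect policy for a zeroed initial state and threading it through every call. A minimal sketch of the adapted loop (untested, reusing the variable names from above):

policy_state = collect_policy.get_initial_state(example_environment.batch_size)
time_step = example_environment.reset()

while not time_step.is_last():
  # Pass the RNN state in, and carry the updated state out for the next step.
  policy_step = collect_policy.action(time_step, policy_state)
  time_step = example_environment.step(policy_step.action)
  policy_state = policy_step.state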