我正在尝试使用TF agent TF agent DQN教程来训练强化学习agent。在我的应用程序中,我有1个操作,包含9个可能的离散值(标记为从0到8(。以下是env.action_spec()
的输出
BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0, dtype=int64), maximum=array(8, dtype=int64))
我想得到概率向量,它包含由训练策略计算的所有操作,并在其他应用程序环境中进行进一步处理。但是,策略只返回具有单个值的log_probability
,而不是所有操作的向量。有没有办法得到概率向量?
from tf_agents.networks import q_network
from tf_agents.agents.dqn import dqn_agent
q_net = q_network.QNetwork(
env.observation_spec(),
env.action_spec(),
fc_layer_params=(32,)
)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
my_agent = dqn_agent.DqnAgent(
env.time_step_spec(),
env.action_spec(),
q_network=q_net,
epsilon_greedy=epsilon,
optimizer=optimizer,
emit_log_probability=True,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=global_step)
my_agent.initialize()
... # training
tf_policy_saver = policy_saver.PolicySaver(my_agent.policy)
tf_policy_saver.save('./policy_dir/')
# making decision using the trained policy
action_step = my_agent.policy.action(time_step)
在dqn_agent.DqnAgent()
DQNAgent中,我设置了emit_log_probability=True
,它应该定义Whether policies emit log probabilities or not.
但是,当我运行action_step = my_agent.policy.action(time_step)
时,它会返回
PolicyStep(action=<tf.Tensor: shape=(1,), dtype=int64, numpy=array([1], dtype=int64)>, state=(), info=PolicyInfo(log_probability=<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>))
我还试着运行action_distribution = saved_policy.distribution(time_step)
,它返回
PolicyStep(action=<tfp.distributions.DeterministicWithLogProbCT 'Deterministic' batch_shape=[1] event_shape=[] dtype=int64>, state=(), info=PolicyInfo(log_probability=<tf.Tensor: shape=(), dtype=float32, numpy=0.0>))
如果在TF.Agent中没有这样的API可用,有没有办法得到这样的概率向量?谢谢
后续问题:
如果我理解正确的话,深度Q网络应该得到state
的输入,并从状态中输出每个动作的Q值。我可以将这个Q值向量传递到softmax函数中,并计算相应的概率向量。事实上,我已经用我自己定制的DQN脚本(没有TF代理(进行了这样的计算。那么问题来了:如何从TF Agent返回Q值向量?
在TF-Agent框架中实现这一点的唯一方法是调用Policy.distribution()
方法,而不是操作方法。这将返回根据网络的Q值计算出的原始分布。emit_log_probability=True
仅影响Policy.action()
返回的PolicyStep
命名元组的info
属性。请注意,此分布可能会受到您传递的操作约束的影响(如果您传递了(;从而非法动作将被标记为具有0概率(即使原始Q值可能已经很高(。
此外,如果您希望看到实际的Q值,而不是它们生成的分布,那么如果不直接对代理附带的Q网络(也连接到代理生成的Policy
对象(采取行动,恐怕无法做到这一点。如果你想知道如何正确地调用Q网络,我建议你在这里看看QPolicy._distribution()
方法是如何做到的。
请注意,这些都不能使用预先实现的驱动程序来完成。您必须显式地构建自己的集合循环,或者实现自己的Driver对象(这基本上是等效的(。