如何在Tf代理中传递自定义环境的批处理大小



我正在使用tf代理库来构建上下文土匪。为此,我正在构建一个自定义环境
我正在创建一个banditpyenvironment并将其封装在TFpyenvironmental中。

tfpyenvironment会自动添加批次大小维度(在观察规范中(。我需要在_observe和_apply_Action方法中说明此批量大小维度。由于根据批量大小,我应该提供所需的(批量大小(观察次数(用于观察(,也根据批量大小的不同,我应该采取批量大小的行动,并提供奖励(用于应用行动(。

我找不到一个关于如何告诉tfenvironment批量大小的例子,而不允许在第一个维度上自动添加1。有人能澄清吗

def __init__(self, batch_size):
self.batchsize=batch_size
observation_spec = BoundedTensorSpec(
(2,), np.int32, minimum=[1,1], maximum=[5,2], name= 'observation')
action_spec = BoundedTensorSpec(
shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')

super(SampleEnvironment, self).__init__(observation_spec, action_spec)
def _observe(self):
batch=[]
for i in range(self.batchsize):
each=tf.cast(np.array([np.random.choice([1,2,3,4,5]),np.random.choice([1,2])]), 'int32')
batch.append(each)
self.observation=np.array(batch)
print("in observe",self.observation)
return np.array(self.observation)

当我试图在上面的观察方法中以某种方式解释batchsize时(使用for循环表示batchsize(,tfenvironment再次在第一个维度上添加1作为batchsize。有没有一种方法可以自动告诉环境批次是3,而不是自动添加1。同时,我该如何在重放缓冲区和代理中解释这个批量大小

这可以使用BatchedPyEnvironment类来完成,如下例所示。从上面看,土匪环境是一个非批处理环境。

下面的SampleEnvironment是问题中显示的banditpyenvironment

batch_size = 4
env= SampleEnvironment()
py_envs = [env for _ in range(0, batch_size)]
batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
tfenv = tf_py_environment.TFPyEnvironment(batched_env)

相关内容

  • 没有找到相关文章

最新更新