如何在 Tf-agents 中传递自定义环境的批量大小

How to pass the batchsize for a custom environment in Tf-agents

我正在使用 tf-agents 库构建上下文强盗。 为此,我正在构建自定义环境。
我正在创建一个 banditpyenvironment 并将其包装在 TFpyenvironment 中。

tfpyenvironment 自动添加批量大小维度(在观察规范中)。我需要在 _observe 和 _apply_Action 方法中考虑这个批量大小维度。由于根据批量大小,我应该提供所需的(批量大小)观察次数(用于观察),并且根据批量大小,我应该采用批量大小的操作数并提供奖励(用于应用操作)。

我找不到关于如何告诉 tfenvironment 批量大小的单个示例,而不让自动将 1 添加到第一个维度。有人可以澄清一下吗

 def __init__(self, batch_size):

    self.batchsize=batch_size
    observation_spec = BoundedTensorSpec(
    (2,), np.int32, minimum=[1,1], maximum=[5,2], name= 'observation')
    action_spec = BoundedTensorSpec(
        shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')


    super(SampleEnvironment, self).__init__(observation_spec, action_spec)

  def _observe(self):
    batch=[]
    for i in range(self.batchsize):
        each=tf.cast(np.array([np.random.choice([1,2,3,4,5]),np.random.choice([1,2])]), 'int32')
        batch.append(each)
    self.observation=np.array(batch)
    print("in observe",self.observation)
    return np.array(self.observation)

当我尝试以某种方式在上面的观察方法中考虑批量大小时(对批量大小使用 for 循环),tfenvironment 再次将 1 添加到第一个维度作为批量大小。 有没有办法自动告诉环境批处理是 3,而不是自动添加 1。同时,我如何在重放缓冲区和代理中考虑这个批处理大小

这可以使用 BatchedPyEnvironment class 完成,如下例所示。上面的 bandit 环境看起来是一个非批处理的环境。

下面的SampleEnvironment是问题中显示的banditpyenvironment

batch_size = 4
env= SampleEnvironment()
py_envs = [env for _ in range(0, batch_size)]
batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
tfenv = tf_py_environment.TFPyEnvironment(batched_env)