Using BatchedPyEnvironment in tf_agents

I'm trying to create a batched-environment version of the SAC agent example from the TensorFlow Agents library; the original code can be found here. I'm also using a custom environment.

I'm after a batched environment setup to make better use of GPU resources and speed up training. My understanding is that by passing batches of trajectories to the GPU, there is less of the overhead that is incurred when moving data from the host (CPU) to the device (GPU).

My custom environment is called SacEnv, and I attempt to create a batched environment as follows:

py_envs = [SacEnv() for _ in range(0, batch_size)]
batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
tf_env = tf_py_environment.TFPyEnvironment(batched_env)
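As a sanity check, both wrappers should report the combined batch size (a minimal check, assuming SacEnv is a standard py_environment.PyEnvironment subclass):

# Both BatchedPyEnvironment and TFPyEnvironment expose a batch_size
# property, which should equal len(py_envs) after batching.
assert batched_env.batch_size == len(py_envs)
assert tf_env.batch_size == len(py_envs)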

I was hoping this would create a batched environment made up of batch_size non-batched environments. However, when running the code I get the following error:

ValueError: Cannot assign value to variable ' Accumulator:0': Shape mismatch.The variable shape (1,), and the assigned value shape (32,) are incompatible.

With the stack trace:

Traceback (most recent call last):
  File "/home/gary/Desktop/code/sac_test/sac_main2.py", line 370, in <module>
    app.run(main)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/gary/Desktop/code/sac_test/sac_main2.py", line 366, in main
    train_eval(FLAGS.root_dir)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/gary/Desktop/code/sac_test/sac_main2.py", line 274, in train_eval
    results = metric_utils.eager_compute(
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/eval/metric_utils.py", line 163, in eager_compute
    common.function(driver.run)(time_step, policy_state)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/drivers/dynamic_episode_driver.py", line 211, in run
    return self._run_fn(
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/utils/common.py", line 188, in with_check_resource_vars
    return fn(*fn_args, **fn_kwargs)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/drivers/dynamic_episode_driver.py", line 238, in _run
    tf.while_loop(
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/drivers/dynamic_episode_driver.py", line 154, in loop_body
    observer_ops = [observer(traj) for observer in self._observers]
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/drivers/dynamic_episode_driver.py", line 154, in <listcomp>
    observer_ops = [observer(traj) for observer in self._observers]
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/metrics/tf_metric.py", line 93, in __call__
    return self._update_state(*args, **kwargs)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/metrics/tf_metric.py", line 81, in _update_state
    return self.call(*arg, **kwargs)
ValueError: in user code:

    File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/metrics/tf_metrics.py", line 176, in call  *
        self._return_accumulator.assign(

    ValueError: Cannot assign value to variable ' Accumulator:0': Shape mismatch.The variable shape (1,), and the assigned value shape (32,) are incompatible.

  In call to configurable 'eager_compute' (<function eager_compute at 0x7fa4d6e5e040>)
  In call to configurable 'train_eval' (<function train_eval at 0x7fa4c8622dc0>)

I have gone over the tf_metric.py code to try to understand the error, but without success. This question looks related: there, a similar problem was resolved by adding the batch size (32) to the initializer of the AverageReturnMetric instance.
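For reference, the fix described there amounts to something like this (a minimal sketch; the metric's batch_size has to match the environment's batch dimension):

from tf_agents.metrics import tf_metrics

# Sketch of the related fix: the metric's internal accumulator variable is
# created with shape (batch_size,), so it must match the batched
# environment's batch dimension rather than the default of 1.
metric = tf_metrics.AverageReturnMetric(batch_size=32)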

The full code is:

# coding=utf-8
# Copyright 2020 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3

r"""Train and Eval SAC.

All hyperparameters come from the SAC paper
https://arxiv.org/pdf/1812.05905.pdf

To run:

```bash
tensorboard --logdir $HOME/tmp/sac/gym/HalfCheetah-v2/ --port 2223 &

python tf_agents/agents/sac/examples/v2/train_eval.py \
  --root_dir=$HOME/tmp/sac/gym/HalfCheetah-v2/ \
  --alsologtostderr
```
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from sac_env import SacEnv

import os
import time

from absl import app
from absl import flags
from absl import logging

import gin
from six.moves import range
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.drivers import dynamic_step_driver
#from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
from tf_agents.environments import batched_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common
from tf_agents.train.utils import strategy_utils


flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_multi_string('gin_file', None, 'Path to the trainer config files.')
flags.DEFINE_multi_string('gin_param', None, 'Gin binding to pass through.')

FLAGS = flags.FLAGS

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)

@gin.configurable
def train_eval(
    root_dir,
    env_name='SacEnv',
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    #initial_collect_steps=10000,
    initial_collect_steps=2000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=31250, # 1000000 / 32
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    #batch_size=256,
    batch_size=32,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=50000,
    policy_checkpoint_interval=50000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):


    py_envs = [SacEnv() for _ in range(0, batch_size)]
    batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
    tf_env = tf_py_environment.TFPyEnvironment(batched_env)
    
    eval_py_envs = [SacEnv() for _ in range(0, batch_size)]
    eval_batched_env = batched_py_environment.BatchedPyEnvironment(envs=eval_py_envs)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_batched_env)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    strategy = strategy_utils.get_strategy(tpu=False, use_gpu=True)

    with strategy.scope():
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=tanh_normal_projection_network
            .TanhNormalProjectionNetwork)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=batch_size,
        max_length=replay_buffer_capacity,
        device="/device:GPU:0")
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    if replay_buffer.num_frames() == 0:
      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    global_step_val = global_step.numpy()
    while global_step_val < num_iterations:
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val,
                     train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss


def main(_):
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval(FLAGS.root_dir)

if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  app.run(main)

What is the proper way to create a batched environment from a custom, non-batched environment? I can share my custom environment, but I don't believe the problem lies there, since the code works fine with a batch size of 1.

Also, any tips on improving GPU utilization in reinforcement-learning scenarios would be greatly appreciated. I have looked at examples of using the TensorBoard profiler to analyze GPU utilization, but those seem to require callbacks and the fit function, which don't appear to apply to the RL use case.

It turns out I was neglecting to pass batch_size when initializing the AverageReturnMetric and AverageEpisodeLengthMetric instances.
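Mirroring what the train_metrics in the code above already do, the eval metrics need the environment's batch size as well; a minimal sketch of the corrected construction (assuming the same batch_size value that was used to build the batched environments):

eval_metrics = [
    tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes, batch_size=batch_size),
    tf_metrics.AverageEpisodeLengthMetric(
        buffer_size=num_eval_episodes, batch_size=batch_size),
]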