分布式 Tensorflow:同步训练无限期停顿

Distributed TensorFlow: Synchronous training stalls indefinitely

我有一个分布式环境,包含一个 ps 任务服务器和两个 worker 任务服务器,每个都在 CPU 上运行。下面的示例在异步模式下可以正常运行,但在同步模式下无法工作。我不确定是不是代码哪里写错了:

import math
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Cluster layout: which hosts act as parameter servers and which as workers.
tf.app.flags.DEFINE_string(
    "ps_hosts", "", "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string(
    "worker_hosts", "", "Comma-separated list of hostname:port pairs")

# Identity of this process within the cluster (consumed by tf.train.Server).
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer(
    "task_index", 0, "Index of task within the job")

# Input data location and mini-batch size.
tf.app.flags.DEFINE_string(
    "data_dir", "/tmp/mnist-data", "Directory for storing mnist data")
tf.app.flags.DEFINE_integer("batch_size", 3, "Training batch size")

FLAGS = tf.app.flags.FLAGS

# Side length used to size the flattened input placeholder
# (IMAGE_PIXELS * IMAGE_PIXELS features per example).
IMAGE_PIXELS = 28

# The training loop in main() stops once the global step reaches this value.
steps = 1000

def main(_):
  """Run one task (ps or worker) of a synchronously trained MNIST softmax model.

  The cluster layout and this process's role are read from FLAGS. A "ps" task
  only serves variables and blocks forever. A "worker" task builds the model,
  wraps the optimizer in `tf.train.SyncReplicasOptimizer`, and runs training
  steps until the global step reaches `steps`, then prints test accuracy.

  Synchronous mode needs more than wrapping the optimizer: the chief worker
  (task_index 0) must also start the optimizer's chief queue runner and run
  its init-tokens op. Without these, every worker blocks indefinitely.

  Args:
    _: Unused; `tf.app.run` passes the remaining command-line arguments here.
  """
  # Local import: only needed to expand "~" in the TensorBoard log directory.
  import os

  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")
  # One replica per worker host; every replica contributes to each update.
  num_workers = len(worker_hosts)

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  tf.logging.set_verbosity(tf.logging.DEBUG)
  if FLAGS.job_name == "ps":
    server.join()
  elif FLAGS.job_name == "worker":
    is_chief = (FLAGS.task_index == 0)

    # Variables go to the ps job; other ops default to the local worker.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

      with tf.name_scope('Input'):
        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="X")
        y_ = tf.placeholder(tf.float32, [None, 10], name="LABELS")

      W = tf.Variable(tf.zeros([IMAGE_PIXELS * IMAGE_PIXELS, 10]), name="W")
      b = tf.Variable(tf.zeros([10]), name="B")
      y = tf.matmul(x, W) + b
      y = tf.identity(y, name="Y")

      with tf.name_scope('CrossEntropy'):
        cross_entropy = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

      # The step counter is bookkeeping, not a model parameter.
      global_step = tf.Variable(0, name="STEP", trainable=False)

      with tf.name_scope('Train'):
        opt = tf.train.GradientDescentOptimizer(0.5)
        opt = tf.train.SyncReplicasOptimizer(opt,
                                replicas_to_aggregate=num_workers,
                                total_num_replicas=num_workers,
                                name="SyncReplicasOptimizer")
        train_step = opt.minimize(cross_entropy, global_step=global_step)

        if is_chief:
          # Initial tokens and the chief queue runner are required by the
          # sync-replicas mode; they only exist after minimize() was called.
          chief_queue_runner = opt.get_chief_queue_runner()
          init_tokens_op = opt.get_init_tokens_op()

      with tf.name_scope('Accuracy'):
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

      saver = tf.train.Saver()
      summary_op = tf.summary.merge_all()

      init_op = tf.global_variables_initializer()

    # Create a "supervisor", which oversees the training process.
    sv = tf.train.Supervisor(is_chief=is_chief,
                             logdir="/tmp/train_logs",
                             init_op=init_op,
                             summary_op=summary_op,
                             saver=saver,
                             global_step=global_step,
                             save_model_secs=600)

    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=True,
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])

    # The supervisor takes care of session initialization, restoring from
    # a checkpoint, and closing when done or an error occurs.
    with sv.managed_session(server.target, config=config) as sess:
      if is_chief:
        # Without the chief queue runner and the initial tokens, every
        # worker's train_step blocks forever.
        print("Starting chief queue runner and running init_tokens_op")
        sv.start_queue_runners(sess, [chief_queue_runner])
        sess.run(init_tokens_op)

      # "~" is not expanded by TensorFlow, so resolve it explicitly. The
      # writer is only used to dump the graph, so close (flush) it right away.
      writer = tf.summary.FileWriter(
          os.path.expanduser("~/tensorboard_data"), sess.graph)
      writer.close()

      # Loop until the supervisor shuts down or `steps` global steps are done.
      step = 0
      while not sv.should_stop() and step < steps:
        print("Starting step %d" % step)

        old_step = step

        batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
        train_feed = {x: batch_xs, y_: batch_ys}

        # Run one synchronous training step and read back the shared
        # global step, which drives the loop's stop condition.
        _, step = sess.run([train_step, global_step], feed_dict=train_feed)

        print("Done step %d, next step %d\n" % (old_step, step))

      # Test trained model
      print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

    # Ask for all the services to stop.
    sv.stop()

# Script entry point: tf.app.run parses the flags and then invokes main().
if __name__ == "__main__":
  tf.app.run(main=main)

ps 任务打印如下:

I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> TF2:2222, 1 -> TF0:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222

而各个 worker 任务打印了类似的内容,随后还有一些其他信息:

I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> TF1:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> TF2:2222, 1 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222
INFO:tensorflow:SyncReplicasV2: replicas_to_aggregate=2; total_num_replicas=2
[...]
I tensorflow/core/common_runtime/simple_placer.cc:841] Train/gradients/CrossEntropy/Mean_grad/Prod_1: (Prod)/job:worker/replica:0/task:1/cpu:0
: /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Sub_2/y: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/concat_1/axis: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/concat_1/values_0: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Slice_1/size: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Sub_1/y: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Rank_2: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/concat/axis: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/concat/values_0: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Slice/size: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Sub/y: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Rank_1: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Rank: (Const): /job:worker/replica:0/task:1/cpu:0
zeros_1: (Const): /job:worker/replica:0/task:1/cpu:0
GradientDescent/value: (Const): /job:ps/replica:0/task:0/cpu:0
Fill/dims: (Const): /job:ps/replica:0/task:0/cpu:0
zeros: (Const): /job:worker/replica:0/task:1/cpu:0
Input/LABELS: (Placeholder): /job:worker/replica:0/task:1/cpu:0
Input/X: (Placeholder): /job:worker/replica:0/task:1/cpu:0
init_all_tables: (NoOp): /job:ps/replica:0/task:0/cpu:0
group_deps/NoOp: (NoOp): /job:ps/replica:0/task:0/cpu:0
report_uninitialized_variables/boolean_mask/strided_slice_1: (StridedSlice): /job:ps/replica:0/task:0/cpu:0
report_uninitialized_variables/boolean_mask/strided_slice: (StridedSlice): /job:ps/replica:0/task:0/cpu:0
[...]
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Slice_1/size: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Sub_1/y: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Rank_2: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/concat/axis: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/concat/values_0: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Slice/size: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Sub/y: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Rank_1: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Rank: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] zeros_1: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] GradientDescent/value: (Const)/job:ps/replica:0/task:0/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] Fill/dims: (Const)/job:ps/replica:0/task:0/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] zeros: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] Input/LABELS: (Placeholder)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] Input/X: (Placeholder)/job:worker/replica:0/task:1/cpu:0

在此之后就不再有任何输出了。我为 SyncReplicasOptimizer 尝试了不同的配置,但似乎都不起作用。

任何帮助将不胜感激!

编辑:以下是在命令行中使用的命令,分别用于 ps 服务器和 worker(各个 worker 的 task_index 不同):

python filename.py --ps_hosts=server1:2222 --worker_hosts=server2:2222,server3:2222 --job_name=ps --task_index=0
python filename.py --ps_hosts=server1:2222 --worker_hosts=server2:2222,server3:2222 --job_name=worker --task_index=0

在查看其他同步分布式 TensorFlow 示例时,我发现了一些能让这段代码正常工作的代码片段。具体来说(在定义 train_step 之后):

if FLAGS.task_index == 0:
    # Task 0 acts as the chief. The sync_replicas mode requires it to hold
    # both the chief queue runner and the op that provides the initial tokens.
    chief_queue_runner = opt.get_chief_queue_runner()
    init_tokens_op = opt.get_init_tokens_op()

和(在循环之前的会话内部):

if FLAGS.task_index == 0:
    # Only the chief (task 0) launches the chief queue runner and then
    # runs the init-tokens op inside the session.
    print("Starting chief queue runner and running init_tokens_op")
    chief_runners = [chief_queue_runner]
    sv.start_queue_runners(sess, chief_runners)
    sess.run(init_tokens_op)

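因此,仅用 SyncReplicasOptimizer 包装优化器是不够的,还需要创建并使用 chief 的 queue_runner 和 init_tokens_op。我不确定为什么会这样,但我希望这对其他人有帮助。(原因大致是:在同步模式下,各个 worker 每一步都要从令牌队列中取到令牌才能继续;负责聚合梯度并补充令牌的操作由 chief 的 queue runner 运行,初始令牌则由 init_tokens_op 放入队列——缺少它们,所有 worker 都会一直阻塞。)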