Where do the weights get updated in this code?

I want to train a model on a distributed system. I found code on GitHub for distributed training, in which the workers send their gradients to a parameter server and the parameter server sends the averaged gradients back to the workers. But in the client/worker-side code I cannot figure out where the received gradients are used to update the weights and biases.

Here is the client/worker code. It receives the initial gradients from the parameter server, then computes the loss and the gradients, and sends the gradient values back to the server.

from __future__ import division
from __future__ import print_function

import numpy as np
import sys
import pickle as pickle
import socket

from datetime import datetime
import time

import tensorflow as tf

import cifar10

TCP_IP = 'some IP'
TCP_PORT = 5014

port = 0
port_main = 0
s = 0

FLAGS = tf.app.flags.FLAGS


tf.app.flags.DEFINE_string('train_dir', '/home/ubuntu/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 5000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")
#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.30)


def safe_recv(size, server_socket):
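    """Receive exactly `size` bytes from server_socket, looping until complete."""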
    data = bytearray()
    recv_size = 0
    while 1:
        try:
            temp = server_socket.recv(size-len(data))
            data.extend(temp)
            recv_size = len(data)
            if recv_size >= size:
                break
        except:
            print("Error")
    data = bytes(data)
    return data


def train():
    """Train CIFAR-10 for a number of steps."""

    g1 = tf.Graph()
    with g1.as_default():
        global_step = tf.Variable(-1, name='global_step',
                                  trainable=False, dtype=tf.int32)
        increment_global_step_op = tf.assign(global_step, global_step+1)

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)
        grads = cifar10.train_part1(loss, global_step)

        only_gradients = [g for g, _ in grads]

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
            checkpoint_dir=FLAGS.train_dir,
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                   tf.train.NanTensorHook(loss),
                   _LoggerHook()],
            config=tf.ConfigProto(
                # log_device_placement=FLAGS.log_device_placement, gpu_options=gpu_options)) as mon_sess:
                log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            global port
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect((TCP_IP, port_main))
            recv_size = safe_recv(17, s)
            recv_size = pickle.loads(recv_size)
            recv_data = safe_recv(recv_size, s)
            var_vals = pickle.loads(recv_data)
            s.close()
            feed_dict = {}
            i = 0
            for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                feed_dict[v] = var_vals[i]
                i = i+1
            print("Received variable values from ps")
            # Opening the socket and connecting to server
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect((TCP_IP, port))
            while not mon_sess.should_stop():
                gradients, step_val = mon_sess.run(
                    [only_gradients, increment_global_step_op], feed_dict=feed_dict)
                # sending the gradients
                send_data = pickle.dumps(gradients, pickle.HIGHEST_PROTOCOL)
                to_send_size = len(send_data)
                send_size = pickle.dumps(to_send_size, pickle.HIGHEST_PROTOCOL)
                s.sendall(send_size)
                s.sendall(send_data)
                # receiving the variable values
                recv_size = safe_recv(17, s)
                recv_size = pickle.loads(recv_size)
                recv_data = safe_recv(recv_size, s)
                var_vals = pickle.loads(recv_data)

                feed_dict = {}
                i = 0
                for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                    feed_dict[v] = var_vals[i]
                    i = i+1
            s.close()


def main(argv=None):  # pylint: disable=unused-argument
    global port
    global port_main
    global s
    if(len(sys.argv) != 3):
        print("<port> <worker-id> required")
        sys.exit()
    port = int(sys.argv[1]) + int(sys.argv[2])
    port_main = int(sys.argv[1])
    print("Connecting to port ", port)
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    total_start_time = time.time()
    train()
    print("--- %s seconds ---" % (time.time() - total_start_time))


if __name__ == '__main__':
    tf.app.run()

Edit:

Here is the train_part1() code:

def train_part1(total_loss, global_step):
  """Train CIFAR-10 model.

  Create an optimizer and compute the gradients for all trainable variables.
  The gradients are returned without being applied.

  Args:
    total_loss: Total loss from loss().
    global_step: Integer Variable counting the number of training steps
      processed.
  Returns:
    grads: list of (gradient, variable) pairs from compute_gradients().
  """
  # Variables that affect learning rate.
  num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)
  tf.summary.scalar('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = _add_loss_summaries(total_loss)

  # Compute gradients.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  return grads

To me it seems that the line

gradients, step_val = mon_sess.run(
    [only_gradients, increment_global_step_op], feed_dict=feed_dict)

receives the new values of the variables via feed_dict, assigns those values to the variables, and performs a training step, during which it only computes and returns the gradients, which are later sent to the parameter server. I expected cifar10.train_part1() (which returns only_gradients) to depend on the variable values and to define the update.
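For reference, only_gradients comes from opt.compute_gradients() inside train_part1(), which only builds the gradient tensors; the update op would come from opt.apply_gradients(), and that op is never created here. A minimal standalone TF1 sketch (hypothetical names, not taken from the repository) of the difference:

import tensorflow as tf

x = tf.Variable(3.0)
loss = tf.square(x)
opt = tf.train.GradientDescentOptimizer(0.1)
grads = opt.compute_gradients(loss)      # list of (gradient, variable) pairs, no update
train_op = opt.apply_gradients(grads)    # this op is what would change the variable

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([g for g, _ in grads]))  # [6.0]; evaluating gradients changes nothing
    print(sess.run(x))                      # 3.0, still the initial value
    sess.run(train_op)
    print(sess.run(x))                      # 2.4, updated only by apply_gradients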

Update: I went through the code and changed my mind. I had to google it and found the following answer, which clarified what is going on.

The gradients are not actually applied anywhere in this code, not even implicitly. Instead, the gradients are sent to the parameter server, the parameter server averages them and applies them to the weights, and then it returns the weights to the local worker, *which uses the received weights instead of the local ones during the session run via feed_dict*. In other words, the local weights are never actually updated and in fact do not matter at all. The key point is that feed_dict allows overriding any tensor output during a session run, and this code overrides the variables.
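A minimal standalone TF1 sketch (hypothetical names, not taken from the repository) of that feed_dict behaviour: feeding a variable overrides its stored value for that single session.run call, without ever assigning to it.

import tensorflow as tf

w = tf.Variable(1.0)
y = w * 2.0

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y))                      # 2.0, uses the stored value of w
    print(sess.run(y, feed_dict={w: 5.0}))  # 10.0, the fed value overrides w for this run
    print(sess.run(w))                      # 1.0, w itself was never updated

That is exactly what the worker loop does: every mon_sess.run call feeds the latest values received from the parameter server, so the values stored in the local variables never change and are simply ignored. The weights are updated only on the parameter-server side. Purely as an illustration of what that side is described as doing (plain NumPy SGD, not the repository's code), it boils down to:

import numpy as np

def ps_step(weights, worker_grads, lr=0.1):
    # worker_grads: one gradient array per worker, each with the same shape as weights
    avg_grad = np.mean(worker_grads, axis=0)
    # apply a plain SGD step; the returned weights are what each worker then feeds in
    return weights - lr * avg_grad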