Custom Keras callbacks and changing the weight (beta) of the regularization term in a variational autoencoder loss function

The variational autoencoder loss function is: Loss = Loss_reconstruction + Beta * Loss_kld. I am trying to efficiently implement cyclic annealing of the Kullback-Leibler divergence weight, that is, changing beta dynamically during training. As a start I subclass tf.keras.callbacks.Callback, but I don't know how to update a tf.keras.Model variable from a custom Keras callback.

I would also like to track how beta changes at the end of each training step (on_train_batch_end). Right now I keep a list in the callback class, but I know Python lists don't play well with TensorFlow: when I fit the model, I get a warning that my on_train_batch_end function is slower than the processing of the batch itself. I think I should use a tf.TensorArray instead of Python lists, but then the tf.TensorArray write method cannot use a tf.Variable for the index (i.e., as the number of steps changes, the index at which the new beta for that step should be written changes). Is there a better way to store the value changes?

It looks like this GitHub repository shows a solution that does not involve a custom tf.keras.Model and uses a different type of KL annealing. Below are a callback and a dummy VAE.

import tensorflow as tf

class CyclicAnnealing(tf.keras.callbacks.Callback):
  """Cyclic annealing from https://arxiv.org/abs/1903.10145
  
  Requires that model tracks training iterations and 
  total number of training iterations. It also requires
  that model has hyperparameter for `M` and `R`.
  """

  def __init__(self, schedule_fxn='sigmoid', **kwargs):
    super().__init__(**kwargs)

    # INEFFICIENT WAY OF LOGGING `betas` AND THE TRAIN STEPS...
    # The `train_iterations` list could be removed because in principle
    # if I have a list of betas, I know that the list of betas is of length
    # (number of samples//batch size) * number of epochs.
    # This is because (number of samples//batch size) * number of epochs is the total number of steps for the model.
    self.betas = []
    self.train_iterations = []

    if schedule_fxn == 'sigmoid':
      self.schedule_fxn = self.sigmoid

    elif schedule_fxn =='linear':
      self.schedule_fxn = self.linear

    else:
      raise ValueError('Invalid arg: `schedule_fxn`')

  def on_epoch_end(self, epoch, logs=None):
    print('\nCurrent anneal weight B =', self.beta)

  def on_train_batch_end(self, batch, logs=None):
    """Computes betas and updates list"""

    # Compute beta
    self.beta = self.beta_tau_cyclic_annealing(self.compute_tau())

    ###################################
    # HOW TO UPDATE BETA IN THE MODEL???
    ###################################

    # Update the lists for logging
    self.betas.append(self.beta)
    # Store a snapshot of the counter value, not the variable object itself
    self.train_iterations.append(int(self.model._train_counter.numpy()))

  def get_annealing_data(self):
    return {'betas': self.betas, 'training_iterations': self.train_iterations}

  def sigmoid(self, x):
    """Monotonic increasing function
    
    :return: tf.constant float32
    """

    return (1/(1+tf.keras.backend.exp(-x)))

  def linear(self, x):
    return x/self.model._R

  def compute_tau(self):
    """Used to determine kld_beta.
    
    :return: tf.constant float32
    """

    t = tf.identity(self.model._train_counter)
    T = self.model._total_training_iterations
    M = self.model._M
    numerator = tf.cast(tf.math.floormod(tf.subtract(t, 1), tf.math.floordiv(T, M)), dtype=tf.float32)
    denominator = tf.cast(tf.math.floordiv(T, M), dtype=tf.float32)
    return tf.math.divide(numerator, denominator)

  def beta_tau_cyclic_annealing(self, tau):
    """Compute change for kld_beta.
    
    :param tau: Increases beta_tau
    :param R: Proportion used to increase Beta w/i cycle.

    :return: tf.constant float32
    """

    R = self.model._R
    if tau <= R:
      return self.schedule_fxn(tau)
    else:
      return tf.constant(1.0)

Dummy VAE:

class VAE(tf.keras.Model):
    def __init__(self, num_samples, batch_size, epochs, features, units, latent_size, kld_beta, M, R, **kwargs):
        """Defines state for model.

        :param num_samples: <class 'int'>
        :param batch_size: <class 'int'>
        :param epochs: <class 'int'>
        :param features: <class 'int'> if input is (n, m), then `features` is the `m` dimension. This param is used with the decoder.
        :param units: <class 'int'> Number of hidden units.
        :param latent_size: <class 'int'> Dimension of latent space z.
        :param kld_beta: <tf.Variable??> for dynamic weight.
        :param M: <class 'int'> Hyperparameter for cyclic annealing.
        :param R: <class 'float'> Hyperparameter for cyclic annealing.
        """
        super().__init__(**kwargs)

        # NEED TO UPDATE THIS SOMEHOW -- I think it should be a tf.Variable?
        self.kld_beta = kld_beta

        # Hyperparameters for CyclicAnnealing
        self._M = M
        self._R = R
        self._total_training_iterations = (num_samples//batch_size) * epochs    

        # Encoder and Decoder not defined, but typically
        # encoder = inputs -> dense -> dense mu and dense log var -> z
        # while decoder = z -> dense -> reconstructions
        self.encoder = Encoder(units, latent_size)
        self.decoder = Decoder(features)

    def call(self, inputs):
        z, mus, log_vars = self.encoder(inputs)
        reconstructions = self.decoder(z)

        kl_loss = self.compute_kl_loss(mus, log_vars)

        # THE BETA WEIGHT NEEDS TO BE DYNAMIC
        weighted_kl_loss = self.kld_beta * kl_loss
      
        self.add_loss(weighted_kl_loss)

        return reconstructions
        
    def compute_kl_loss(self, mus, log_vars):
        return -0.5 * tf.reduce_mean(1. + log_vars - tf.exp(log_vars) - tf.pow(mus, 2))

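Regarding your first question: it depends on how you intend to apply the gradient updates with your optimizer (e.g. Adam). When training VAEs with TensorFlow/Keras, I usually use the @tf.function decorator on a training step that computes my model's loss and updates the model parameters based on it: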

@tf.function
def train_step(self, model, batch, gamma, capacity):
    with tf.GradientTape() as tape:
        x, c = batch
        loss = compute_loss(model, x, c, gamma, capacity)
        tf.print('Total loss: ', loss)

    gradients = tape.gradient(loss, model.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Note the variables gamma and capacity. They are defined as terms that affect the loss function. I update them after x epochs as follows:

new_weight = min(tf.keras.backend.get_value(capacity) + (20. / capacity_annealtime), 20.)
tf.keras.backend.set_value(capacity, new_weight)
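Applied to the callback in the question, the same idea could be sketched as follows. This is only a sketch under stated assumptions: `BetaScheduler`, `cyclic_beta`, and their parameters are hypothetical names, the linear cyclic schedule is one reading of the schedule described in the question, and the beta weight is assumed to be a non-trainable tf.Variable that the model's `call` multiplies with the KL loss (like `self.kld_beta` in the question). The schedule value is computed in plain Python, so logging it to a list on every batch should stay cheap.

```python
import tensorflow as tf


def cyclic_beta(step, total_steps, n_cycles, ratio):
    """Linear cyclic schedule (hypothetical helper); `step` is 0-based.

    Within each cycle, beta ramps linearly from 0 to 1 over the first
    `ratio` fraction of the cycle and then stays at 1.
    """
    cycle_len = max(total_steps // n_cycles, 1)
    tau = (step % cycle_len) / cycle_len  # position inside the current cycle
    return min(tau / ratio, 1.0)


class BetaScheduler(tf.keras.callbacks.Callback):
    """Assigns a new beta to a shared tf.Variable before every train batch."""

    def __init__(self, beta_var, total_steps, n_cycles=4, ratio=0.5):
        super().__init__()
        self.beta_var = beta_var      # the same variable the model reads
        self.total_steps = total_steps
        self.n_cycles = n_cycles
        self.ratio = ratio
        self.step = 0                 # global step across epochs
        self.history = []             # plain Python floats, one per step

    def on_train_batch_begin(self, batch, logs=None):
        beta = cyclic_beta(self.step, self.total_steps, self.n_cycles, self.ratio)
        self.beta_var.assign(beta)    # update the variable in place
        self.history.append(beta)
        self.step += 1


# Assumed setup: the variable would also be handed to the model as its KL weight.
beta = tf.Variable(0.0, trainable=False, dtype=tf.float32)
scheduler = BetaScheduler(beta, total_steps=8, n_cycles=2, ratio=0.5)
```

With a model like the VAE in the question, `beta` would be passed in as `kld_beta` and `scheduler` would be passed to `model.fit(..., callbacks=[scheduler])`; after training, `scheduler.history` holds one beta per step.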

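At this point you can easily save new_weight for logging purposes, or you can define a custom TensorFlow logger that writes to a file. If you really want to use an array, you can simply define a TF array: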

this_array = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

and update it after x steps:

# write() returns the updated TensorArray, so keep the result
this_array = this_array.write(this_array.size(), new_beta_weight)

You can also use a second array and update it at the same time to record the epoch or batch at which new_beta_weight was updated.
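The two-array idea might look like the sketch below, run eagerly. The beta values are made up for illustration, and `betas`, `steps`, `beta_log`, and `step_log` are hypothetical names. Outside a tf.function, a plain Python list is probably simpler; the TensorArray form is mainly of interest when the logging happens inside graph code.

```python
import tensorflow as tf

# Two growable arrays: one for the beta values, one for the step they belong to.
betas = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
steps = tf.TensorArray(tf.int32, size=0, dynamic_size=True)

# Illustrative values only; in practice these would come from the schedule.
for step, value in enumerate([0.0, 0.5, 1.0]):
    idx = betas.size()  # next free slot, so no external index is needed
    # write() returns the updated array; reassign to keep the result
    betas = betas.write(idx, tf.constant(value, dtype=tf.float32))
    steps = steps.write(idx, tf.constant(step, dtype=tf.int32))

# stack() turns each array into a single tensor for saving or plotting.
beta_log = betas.stack().numpy().tolist()
step_log = steps.stack().numpy().tolist()
```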

Finally, the loss function itself looks like this:

def compute_loss(model, x, c, gamma_weight, capacity_weight):

  mean, logvar = model.encode(x, c)

  z = model.reparameterize(mean, logvar)
  reconstruction = model.decode(z, c)

  # Per-element cross-entropy, summed over the feature axis
  total_reconstruction_loss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=x, logits=reconstruction)
  total_reconstruction_loss = tf.reduce_sum(total_reconstruction_loss, 1)

  kl_loss = 1 + logvar - tf.square(mean) - tf.exp(logvar)
  kl_loss = tf.reduce_mean(kl_loss)
  kl_loss *= -0.5

  total_loss = tf.reduce_mean(total_reconstruction_loss * 3 + (
        gamma_weight * tf.abs(kl_loss - capacity_weight)))
  return total_loss

Note that the model is of type tf.keras.Model. Hopefully this gives you a different perspective on this particular topic.
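To make the capacity term in that loss concrete, here is a small plain-Python sketch of the same arithmetic on hand-picked numbers. The helper names `kl_term` and `capacity_penalty` are hypothetical, and the inputs are illustrative rather than taken from a trained model.

```python
import math


def kl_term(means, logvars):
    """-0.5 * mean(1 + logvar - mean^2 - exp(logvar)), as in compute_loss."""
    vals = [1.0 + lv - m * m - math.exp(lv) for m, lv in zip(means, logvars)]
    return -0.5 * sum(vals) / len(vals)


def capacity_penalty(kl, gamma, capacity):
    """gamma * |kl - capacity|: zero only when the KL term equals the capacity."""
    return gamma * abs(kl - capacity)


# Posterior equal to the standard normal prior: the KL term vanishes.
kl_prior = kl_term([0.0, 0.0], [0.0, 0.0])
# Means shifted to 1 with unit variance: the KL term grows.
kl_shifted = kl_term([1.0, 1.0], [0.0, 0.0])

penalty_prior = capacity_penalty(kl_prior, gamma=2.0, capacity=0.5)
penalty_shifted = capacity_penalty(kl_shifted, gamma=2.0, capacity=0.5)
```

The penalty is smallest when the KL term sits at the current capacity, so raising the capacity over time lets the KL term grow gradually instead of being pushed toward zero.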