在使用批处理数据集训练网络时,我应该如何跟踪总损失?

How should I keep track of total loss while training a network with a batched dataset?

我正在尝试通过将梯度应用于优化器来训练鉴别器网络。但是,当我使用 tf.GradientTape 查找损失 w.r.t 训练变量的梯度时,返回 None。这是训练循环:

def train_step():
  #Generate noisy seeds
  noise = tf.random.normal([BATCH_SIZE, noise_dim])
  with tf.GradientTape() as disc_tape:
    pattern = generator(noise)
    pattern = tf.reshape(tensor=pattern, shape=(28,28,1))
    dataset = get_data_set(pattern)
    disc_loss = tf.Variable(shape=(1,2), initial_value=[[0,0]], dtype=tf.float32)
    disc_tape.watch(disc_loss)
    for batch in dataset:
        disc_loss.assign_add(discriminator(batch, training=True))

  disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

代码说明

生成器网络从噪声中生成 'pattern'。然后,我通过对张量应用各种卷积从该模式生成数据集。返回的数据集是分批处理的,因此我遍历数据集并通过将这批损失与总损失相加来跟踪鉴别器的损失。

我所知道的

tf.GradientTape returns None 当两个变量之间没有图形连接时。但是损失和可训练变量之间不存在图形联系吗?我相信我的错误与我如何跟踪 disc_loss tf.Variable

中的损失有关

我的问题

如何在遍历批处理数据集时跟踪损失,以便以后可以使用它来计算梯度?

这里的基本答案是 tf.Variable 的 assign_add 函数不可微分,因此无法计算变量 disc_loss 和判别器可训练变量之间的梯度。

在这个非常具体的案例中,答案是

disc_loss = disc_loss + discriminator(batch, training=True)

在以后遇到类似问题的时候,一定要检查梯度带观察时使用的所有操作都是可微的。

This link 有一个可微分和不可微分的张量流操作列表。我发现它非常有用。