Tensorflow 与 PyTorch 中的计算梯度

Computing gradient in Tensorflow vs PyTorch

我正在尝试为一个简单线性模型的损失计算梯度。但是,我遇到的问题是,在使用 TensorFlow 时,梯度计算为 'none'。为什么会发生这种情况以及如何使用 TensorFlow 计算梯度?

import numpy as np
import tensorflow as tf

inputs = np.array([[73, 67, 43], 
                   [91, 88, 64], 
                   [87, 134, 58], 
                   [102, 43, 37], 
                   [69, 96, 70]], dtype='float32')

targets = np.array([[56, 70], 
                    [81, 101], 
                    [119, 133], 
                    [22, 37], 
                    [103, 119]], dtype='float32')

inputs = tf.convert_to_tensor(inputs)
targets = tf.convert_to_tensor(targets)

w = tf.random.normal(shape=(2, 3))
b = tf.random.normal(shape=(2,))
print(w, b)

def model(x):
  return tf.matmul(x, w, transpose_b = True) + b

def mse(t1, t2):
  diff = t1-t2
  return tf.reduce_sum(diff * diff) / tf.cast(tf.size(diff), 'float32')

with tf.GradientTape() as tape:
  pred = model(inputs)
  loss = mse(pred, targets)

print(tape.gradient(loss, [w, b]))

这是使用 PyTorch 的工作代码。梯度按预期计算。

import torch

inputs = np.array([[73, 67, 43], 
                   [91, 88, 64], 
                   [87, 134, 58], 
                   [102, 43, 37], 
                   [69, 96, 70]], dtype='float32')

targets = np.array([[56, 70], 
                    [81, 101], 
                    [119, 133], 
                    [22, 37], 
                    [103, 119]], dtype='float32')

inputs = torch.from_numpy(inputs)
targets = torch.from_numpy(targets)

w = torch.randn(2, 3, requires_grad = True)
b = torch.randn(2, requires_grad = True)

def model(x):
  return x @ w.t() + b

def mse(t1, t2):
  diff = t1 - t2
  return torch.sum(diff * diff) / diff.numel()

pred = model(inputs)
loss = mse(pred, targets)
loss.backward()

print(w.grad)
print(b.grad)

您的代码不起作用,因为在 tensorflow 中,仅计算 tf.Variable 的梯度。创建图层时,TF 会自动将其权重和偏差标记为变量(除非您指定 trainable=False)。

因此,为了让您的代码正常工作,您需要做的就是用 tf.Variable

包裹 wb
w = tf.Variable(tf.random.normal(shape=(2, 3)), name='w')
b = tf.Variable(tf.random.normal(shape=(2,)), name='b')

使用这些行来定义您的权重和偏差,您将在最终打印中获得实际值。