为什么 model.get_weights() 返回空列表?是 TensorFlow 的 bug 吗?

Why is model.get_weights() empty? Is it a TensorFlow bug?

我正在尝试实现 MAML,但遇到了问题,所以写了一个简化版本来说明我的困惑。如果使用 `optimizer.apply_gradients` 用梯度更新参数,可以通过 `model.get_weights()` 获得模型权重;但如果自己手动用梯度更新参数,`model.get_weights()` 只会返回空列表。

import tensorflow as tf
from tensorflow.keras import layers, activations, losses, Model, optimizers, models
import numpy as np


class MAMLmodel(Model):
    """Small two-layer model for the MAML example: Dense(2) followed by Dense(1).

    No activation argument is passed to either layer. The script drives the
    model through ``forward``; ``call`` is also provided so that the standard
    Keras entry points (``model(inputs)``, ``predict``, ``fit``) work as well.
    """

    def __init__(self):
        super().__init__()

        # NOTE(review): `input_shape=(3, )` suggests 3 input features, but the
        # script feeds tensors of shape (3, 1) (3 samples, 1 feature). In a
        # subclassed Model the layer is presumably built from its first real
        # input, so this argument looks unused -- confirm before relying on it.
        self.Dense1 = layers.Dense(2, input_shape=(3, ))
        self.Dense2 = layers.Dense(1)

    def forward(self, inputs):
        """Apply Dense1 and then Dense2 to ``inputs`` and return the result."""
        x = self.Dense1(inputs)
        x = self.Dense2(x)

        return x

    def call(self, inputs):
        """Keras forward-pass hook used by ``model(inputs)``; delegates to ``forward``."""
        # A subclassed keras.Model must implement `call`; without it,
        # `model(x)` cannot run. Delegating keeps `forward` as the one definition.
        return self.forward(inputs)

def compute_loss(y_true, y_pred):
    """Return the Keras mean squared error between ``y_true`` and ``y_pred``.

    NOTE(review): Keras documents ``mean_squared_error`` as averaging over the
    last axis only, so this likely yields one loss value per sample rather than
    a scalar; ``GradientTape.gradient`` would then differentiate their sum.
    """
    per_sample_loss = losses.mean_squared_error(y_true, y_pred)
    return per_sample_loss

# Toy data for 3 tasks: x1[i] and y1[i] each hold 3 samples with a single
# feature / single target, i.e. shape (3, 1).
x1 = [[[1], [1], [1]],
      [[1], [1], [1]],
      [[1], [1], [1]]]

y1 = [[[0], [0], [0]],
      [[0], [0], [0]],
      [[0], [0], [0]]]
# NOTE(review): built from Python ints, so these are presumably int32 tensors
# that the layers / loss must cast to float -- confirm.
x1 = tf.convert_to_tensor(x1)
y1 = tf.convert_to_tensor(y1) 

inner_train_step = 1
batch_size = 3  # number of tasks iterated below (one per row of x1 / y1)
lr_inner = 0.001

model = MAMLmodel()
# Only used by the commented-out optimizer-based update below.
inner_optimizer = optimizers.Adam()

for i in range(batch_size):
    # Observed by the asker: with inner_train_step >= 2 the second inner step
    # receives an empty gradient list, because the first step's manual update
    # (below) has already replaced the model's variables.
    for inner_step in range(inner_train_step):
        with tf.GradientTape() as support_tape:
            support_tape.watch(model.trainable_variables)
            y_pred = model.forward(x1[i])
            support_loss = compute_loss(y1[i], y_pred)

        gradients = support_tape.gradient(support_loss, model.trainable_variables)
        # Alternative that works: let the optimizer update the variables in place.
        # inner_optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # BUG (the subject of this question): each assignment rebinds the layer's
        # `kernel` / `bias` attribute to the plain tensor returned by tf.subtract,
        # instead of updating the existing tf.Variable in place. Afterwards the
        # model no longer has trainable variables, so later `gradient` calls
        # return an empty list. Indexing `gradients[k]` on such an empty list
        # would then fail.
        # Also assumes model.layers order matches model.trainable_variables
        # order, with exactly [kernel, bias] per layer (k steps by 2).
        k = 0
        for j in range(len(model.layers)):
            model.layers[j].kernel = tf.subtract(model.layers[j].kernel, tf.multiply(lr_inner, gradients[k]))
            model.layers[j].bias = tf.subtract(model.layers[j].bias, tf.multiply(lr_inner, gradients[k + 1]))
            k += 2

    # With optimizer.apply_gradients (an in-place update) this prints the weights.
    # With the manual rebinding above it prints an empty list, because the
    # layers no longer hold any tf.Variable.
    print(model.get_weights())

我找不到代码中的问题,所以怀疑是 TensorFlow 的 bug。请帮帮我!!!这个错误让我睡不着觉。

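这不是 TensorFlow 的 bug :) 您是用普通张量(`tf.Tensor`)来"更新"模型的变量(`tf.Variable`)的。`model.layers[j].kernel = tf.subtract(...)` 会把层的 `kernel`/`bias` 属性替换成一个普通张量,而不是原地修改原来的变量。因此在第二次迭代中调用 `.gradient(support_loss, model.trainable_variables)` 时,模型实际上已经没有任何可训练变量了,`model.get_weights()` 也就只能返回空列表。请像下面这样修改代码,改用 `assign_sub` 这类原地修改变量的方法: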

import tensorflow as tf
from tensorflow.keras import layers, activations, losses, Model, optimizers, models
import numpy as np


class MAMLmodel(Model):
    """Small two-layer model for the MAML example: Dense(2) followed by Dense(1).

    No activation argument is passed to either layer. The script drives the
    model through ``forward``; ``call`` is also provided so that the standard
    Keras entry points (``model(inputs)``, ``predict``, ``fit``) work as well.
    """

    def __init__(self):
        super().__init__()

        # NOTE(review): `input_shape=(3, )` suggests 3 input features, but the
        # script feeds tensors of shape (3, 1) (3 samples, 1 feature). In a
        # subclassed Model the layer is presumably built from its first real
        # input, so this argument looks unused -- confirm before relying on it.
        self.Dense1 = layers.Dense(2, input_shape=(3, ))
        self.Dense2 = layers.Dense(1)

    def forward(self, inputs):
        """Apply Dense1 and then Dense2 to ``inputs`` and return the result."""
        x = self.Dense1(inputs)
        x = self.Dense2(x)

        return x

    def call(self, inputs):
        """Keras forward-pass hook used by ``model(inputs)``; delegates to ``forward``."""
        # A subclassed keras.Model must implement `call`; without it,
        # `model(x)` cannot run. Delegating keeps `forward` as the one definition.
        return self.forward(inputs)

def compute_loss(y_true, y_pred):
    """Return the Keras mean squared error between ``y_true`` and ``y_pred``.

    NOTE(review): Keras documents ``mean_squared_error`` as averaging over the
    last axis only, so this likely yields one loss value per sample rather than
    a scalar; ``GradientTape.gradient`` would then differentiate their sum.
    """
    per_sample_loss = losses.mean_squared_error(y_true, y_pred)
    return per_sample_loss

# Toy data for 3 tasks: x1[i] and y1[i] each hold 3 samples with a single
# feature / single target, i.e. shape (3, 1).
x1 = [[[1], [1], [1]],
      [[1], [1], [1]],
      [[1], [1], [1]]]

y1 = [[[0], [0], [0]],
      [[0], [0], [0]],
      [[0], [0], [0]]]
# Build float tensors explicitly rather than integer tensors from the Python
# ints, so the layers and the loss do not have to cast their inputs.
x1 = tf.convert_to_tensor(x1, dtype=tf.float32)
y1 = tf.convert_to_tensor(y1, dtype=tf.float32)

inner_train_step = 2
batch_size = 3  # number of tasks iterated below (one per row of x1 / y1)
lr_inner = 0.001

model = MAMLmodel()
# Only used by the optimizer-based alternative mentioned in the inner loop.
inner_optimizer = optimizers.Adam()

for i in range(batch_size):
    # Because the update below modifies the variables in place, every inner
    # step (not just the first) sees the full set of trainable variables.
    for inner_step in range(inner_train_step):
        # Trainable variables read inside the tape are watched automatically,
        # so no explicit support_tape.watch(...) is needed.
        with tf.GradientTape() as support_tape:
            y_pred = model.forward(x1[i])
            support_loss = compute_loss(y1[i], y_pred)

        gradients = support_tape.gradient(support_loss, model.trainable_variables)
        print(f'Number of computed gradients: {len(gradients)}')

        # Alternative: inner_optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # Manual SGD step. assign_sub mutates each tf.Variable in place, so the
        # layers keep tracking it; rebinding `layer.kernel = <tensor>` would
        # replace the variable with a plain tensor and drop it from the model.
        # Zipping with model.trainable_variables pairs each gradient with the
        # exact variable it was computed for, without assuming that every layer
        # has exactly one kernel and one bias.
        for variable, gradient in zip(model.trainable_variables, gradients):
            variable.assign_sub(lr_inner * gradient)

    # The variables were updated in place, so their current values are printed.
    print(f'Get weights: {model.get_weights()}')