在 keras 模型中执行差分输入以用于损失

Performing Differentiation wrt input within a keras model for use in loss

keras中是否有任何层计算输入的导数?例如,如果输入x,第一层就是f(x),那么下一层的输出应该是f'(x)。这里有多个关于这个主题的问题,但所有问题都涉及模型外导数的计算。本质上,我想创建一个神经网络,其损失函数涉及输入的雅可比矩阵和粗麻布矩阵。

我试过以下方法

import keras.backend as K

def create_model():

    x = keras.Input(shape = (10,))
    layer = Dense(1, activation = "sigmoid")
    output = layer(x)

    jac = K.gradients(output, x)
    
    model = keras.Model(inputs=x, outputs=jac)
    
    return model

model = create_model()
X = np.random.uniform(size = (3, 10))

这是报错tf.gradients is not supported when eager execution is enabled. Use tf.GradientTape instead.

所以我尝试使用那个

def create_model2():
    with tf.GradientTape() as tape:
        x = keras.Input(shape = (10,))
        layer = Dense(1, activation = "sigmoid")
        output = layer(x)

    jac = tape.gradient(output, x)
    
    model = keras.Model(inputs=x, outputs=jac)
    
    return model

model = create_model2()
X = np.random.uniform(size = (3, 10))

但这告诉我 'KerasTensor' object has no attribute '_id'

这两种方法在模型外都可以正常工作。我的最终目标是在损失函数中使用 Jacobian 和 Hessian,因此也欢迎使用其他方法

不确定你到底想做什么,但可以尝试自定义 Kerastf.gradients:

import tensorflow as tf
tf.random.set_seed(111)

class GradientLayer(tf.keras.layers.Layer):
  def __init__(self):
    super(GradientLayer, self).__init__()
    self.dense = tf.keras.layers.Dense(1, activation = "sigmoid")
  
  @tf.function
  def call(self, inputs):
    outputs = self.dense(inputs)
    return tf.gradients(outputs, inputs)


def create_model2():
    gradient_layer = GradientLayer()
    inputs = tf.keras.layers.Input(shape = (10,))
    outputs = gradient_layer(inputs)    
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    
    return model

model = create_model2()
X = tf.random.uniform((3, 10))
print(model(X))
tf.Tensor(
[[-0.07935508 -0.12471244 -0.0702782  -0.06729251  0.14465885 -0.0818079
  -0.08996294  0.07622238  0.11422144 -0.08126545]
 [-0.08666676 -0.13620329 -0.07675356 -0.07349276  0.15798753 -0.08934557
  -0.09825202  0.08324542  0.12474566 -0.08875315]
 [-0.08661086 -0.13611545 -0.07670406 -0.07344536  0.15788564 -0.08928795
  -0.09818865  0.08319173  0.12466521 -0.08869591]], shape=(3, 10), dtype=float32)