如何在 Tensorflow/Keras 中的 2 层之间创建循环连接?

How to create a recurrent connection between 2 layers in Tensorflow/Keras?

基本上我想做的是采用以下非常简单的前馈图:

然后添加一个循环层,将第二个 Dense 层的输出作为第一个 Dense 层的输入,如下所示。这两种模型显然都是对我的实际用例的简化,但我想我所要求的一般原则对两者都适用。

我想知道在 Tensorflow 甚至 keras 中是否有一种有效的方法来实现这一点,尤其是在 GPU 处理效率方面。虽然我相当有信心我可以在 Tensorflow 中破解一个自定义模型来实现这个 function-wise,但我对这种自定义模型的 GPU 处理效率持悲观态度。因此,如果有人知道实现这些 2 层之间的循环连接 的有效方法,我将不胜感激。感谢您的时间! =)


为了完整起见,这里是创建第一个简单前馈图的代码。我通过图像编辑创建的循环图。

inputs = tf.keras.Input(shape=(128,))

h_1 = tf.keras.layers.Dense(64)(inputs)
h_2 = tf.keras.layers.Dense(32)(h_1)
out = tf.keras.layers.Dense(16)(h_2)

model = tf.keras.Model(inputs, out)

由于我的问题没有收到任何答案,我想分享我想出的解决方案,以防有人通过搜索找到这个问题。

如果您找到或提出更好的解决方案,请告诉我 - 谢谢!

class SimpleModel(tf.keras.Model):
    def __init__(self, input_shape, *args, **kwargs):
        super(SimpleModel, self).__init__(*args, **kwargs)
        # Create node layers
        self.node_1 = tf.keras.layers.InputLayer(input_shape=input_shape)
        self.node_2 = tf.keras.layers.Dense(64, activation='sigmoid')
        self.node_3 = tf.keras.layers.Dense(32, activation='sigmoid')
        self.node_4 = tf.keras.layers.Dense(16, activation='sigmoid')
        self.conn_3_2_recurrent_state = None

        # Create recurrent connection states
        node_1_output_shape = self.node_1.compute_output_shape(input_shape)
        node_2_output_shape = self.node_2.compute_output_shape(node_1_output_shape)
        node_3_output_shape = self.node_3.compute_output_shape(node_2_output_shape)

        self.conn_3_2_recurrent_state = tf.Variable(initial_value=self.node_3(tf.ones(shape=node_2_output_shape)),
                                                    trainable=False,
                                                    validate_shape=False,
                                                    dtype=tf.float32)
        # OR
        # self.conn_3_2_recurrent_state = tf.random.uniform(shape=node_3_output_shape, minval=0.123, maxval=4.56)
        # OR
        # self.conn_3_2_recurrent_state = tf.ones(shape=node_3_output_shape)
        # OR
        # self.conn_3_2_recurrent_state = tf.zeros(shape=node_3_output_shape)

    def call(self, inputs):
        x = self.node_1(inputs)

        #tf.print(self.conn_3_2_recurrent_state)
        #tf.print(self.conn_3_2_recurrent_state.shape)

        x = tf.keras.layers.Concatenate(axis=-1)([x, self.conn_3_2_recurrent_state])
        x = self.node_2(x)
        x = self.node_3(x)

        self.conn_3_2_recurrent_state.assign(x)
        #tf.print(self.conn_3_2_recurrent_state)
        #tf.print(self.conn_3_2_recurrent_state.shape)

        x = self.node_4(x)
        return x


# Demonstrate statefulness of model (uncomment tf prints in model.call())
model = SimpleModel(input_shape=(10, 128))
x = tf.ones(shape=(10, 128))
model(x)
model(x)


# Demonstrate trainability of the recurrent connection TF model
x = tf.random.uniform(shape=(10, 128))
y = tf.ones(shape=(10, 16))

model = SimpleModel(input_shape=(10, 128))
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(x=x, y=y, epochs=100)