连接 ConvLSTM2D 模型和表格模型的更好方法
Better way to concatenate ConvLSTM2D model and Tabular model
我建立了一个模型,该模型将一个时间序列的 3 个图像以及 5 个数字信息作为输入,并生成该时间序列的下三个图像。
我通过以下方式完成了这项工作:
- 构建用于处理图像的 ConvLSTM2D 模型(与 Keras 文档 here 中列出的示例非常相似)。输入大小=(3x128x128x3)
- 为具有几个密集层的表格数据构建一个简单模型。输入大小=(1,5)
- 连接这两个模型
- 有一个 Conv3D 模型可以生成接下来的 3 个图像
LSTM 模型产生大小为 393216 (3x128x128x8) 的输出。现在我必须将表格模型的输出设置为 49,152,以便我可以在下一层中获得 442368 (3x128x128x9) 的输入大小。因此,这种不必要的 inflation 表格模型的密集层使得原本高效的 LSTM 模型表现得很糟糕。
有没有更好的方法来连接这两个模型?有没有一种方法可以在表格模型的密集层中输出 10?
型号:
x_input = Input(shape=(None, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
x = BatchNormalization()(x)
x = Flatten()(x)
# x = MaxPooling3D()(x)
x_tab_input = Input(shape=(5))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(49152, activation="relu")(x_tab)
x_tab = Flatten()(x_tab)
concat = Concatenate()([x, x_tab])
output = Reshape((3,128,128,9))(concat)
output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)
model = Model([x_input, x_tab_input], output)
model.compile(loss='mae', optimizer='rmsprop')
模型总结:
Model: "functional_3"
______________________________________________________________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
======================================================================================================================================================
input_4 (InputLayer) [(None, None, 128, 128, 3)] 0
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_9 (ConvLSTM2D) (None, None, 128, 128, 32) 40448 input_4[0][0]
______________________________________________________________________________________________________________________________________________________
batch_normalization_9 (BatchNormalization) (None, None, 128, 128, 32) 128 conv_lst_m2d_9[0][0]
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_10 (ConvLSTM2D) (None, None, 128, 128, 16) 27712 batch_normalization_9[0][0]
______________________________________________________________________________________________________________________________________________________
batch_normalization_10 (BatchNormalization) (None, None, 128, 128, 16) 64 conv_lst_m2d_10[0][0]
______________________________________________________________________________________________________________________________________________________
input_5 (InputLayer) [(None, 5)] 0
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_11 (ConvLSTM2D) (None, None, 128, 128, 8) 6944 batch_normalization_10[0][0]
______________________________________________________________________________________________________________________________________________________
dense (Dense) (None, 100) 600 input_5[0][0]
______________________________________________________________________________________________________________________________________________________
batch_normalization_11 (BatchNormalization) (None, None, 128, 128, 8) 32 conv_lst_m2d_11[0][0]
______________________________________________________________________________________________________________________________________________________
dense_1 (Dense) (None, 49152) 4964352 dense[0][0]
______________________________________________________________________________________________________________________________________________________
flatten_3 (Flatten) (None, None) 0 batch_normalization_11[0][0]
______________________________________________________________________________________________________________________________________________________
flatten_4 (Flatten) (None, 49152) 0 dense_1[0][0]
______________________________________________________________________________________________________________________________________________________
concatenate (Concatenate) (None, None) 0 flatten_3[0][0]
flatten_4[0][0]
______________________________________________________________________________________________________________________________________________________
reshape_2 (Reshape) (None, 3, 128, 128, 9) 0 concatenate[0][0]
______________________________________________________________________________________________________________________________________________________
conv3d_2 (Conv3D) (None, 3, 128, 128, 3) 732 reshape_2[0][0]
======================================================================================================================================================
Total params: 5,041,012
Trainable params: 5,040,900
Non-trainable params: 112
______________________________________________________________________________________________________________________________________________________
我同意你的看法,巨大的 Dense
层(具有数百万个参数)可能会阻碍模型的性能。您可以选择以下两种方法之一,而不是 膨胀 带有 Dense
层的表格数据。
选项 1: 平铺 x_tab
张量,使其与您想要的形状相匹配。这可以通过以下步骤实现:
首先,不需要展平ConvLSTM2D
的编码张量:
x_input = Input(shape=(3, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
x = BatchNormalization()(x) # Shape=(None, None, 128, 128, 8)
# Commented: x = Flatten()(x)
其次,您可以使用一层或多层 Dense
处理表格数据。例如:
dim = 10
x_tab_input = Input(shape=(5))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(dim, activation="relu")(x_tab)
# x_tab = Flatten()(x_tab) # Note: Flattening a 2D tensor leaves the tensor unchanged
第三,我们包装张量流操作 tf.tile in a Lambda 层,有效地创建张量的副本 x_tab
以便它匹配所需的形状:
def repeat_tabular(x_tab):
h = x_tab[:, None, None, None, :] # Shape=(bs, 1, 1, 1, dim)
h = tf.tile(h, [1, 3, 128, 128, 1]) # Shape=(bs, 3, 128, 128, dim)
return h
x_tab = Lambda(repeat_tabular)(x_tab)
最后,我们沿着最后一个轴连接 x
和平铺的 x_tab
张量(您也可以考虑沿着第一个轴连接,对应于通道的维度)
concat = Concatenate(axis=-1)([x, x_tab]) # Shape=(3,128,128,8+dim)
output = concat
output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)
# ...
请注意,从某种意义上说,该解决方案可能有点天真,因为模型没有将图像的输入序列编码为低维表示,从而限制了网络的接受域并可能导致性能下降。
选项 2: 与自动编码器和 U-Net 类似,可能需要将图像序列编码为低维表示,以便丢弃不需要的变化(例如噪声),同时保留有意义的信号(例如需要推断序列的下 3 个图像)。这可以通过以下方式实现:
首先,将输入的图像序列编码为低维二维张量。例如,类似以下内容的内容:
x_input = Input(shape=(None, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2, return_sequences=False)(x)
x = BatchNormalization()(x)
x = Flatten()(x)
x = Dense(64, activation='relu')(x)
请注意,最后一个 ConvLSTM2D
没有返回序列。您可能想探索不同的编码器来达到这一点(例如,您也可以在此处使用池化层)。
其次,使用 Dense
层处理表格数据。例如:
dim = 10
x_tab_input = Input(shape=(5))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(dim, activation="relu")(x_tab)
第三,连接前两个流的数据:
concat = Concatenate(axis=-1)([x, x_tab])
第四,使用 Dense
+ Reshape
层将连接的向量投影到一系列低分辨率图像中:
h = Dense(3 * 32 * 32 * 3)(concat)
output = Reshape((3, 32, 32, 3))(h)
output
的形状允许将图像上采样为 (128, 128, 3)
的形状,但它在其他方面是任意的(例如,您可能还想在这里进行实验)。
最后,应用一层或多层 Conv3DTranspose 层以获得所需的输出(例如 3 张形状为 (128, 128, 3)
的图像)。
output = tf.keras.layers.Conv3DTranspose(filters=50, kernel_size=(3, 3, 3),
strides=(1, 2, 2), padding='same',
activation='relu')(output)
output = tf.keras.layers.Conv3DTranspose(filters=3, kernel_size=(3, 3, 3),
strides=(1, 2, 2), padding='same',
activation='relu')(output) # Shape=(None, 3, 128, 128, 3)
讨论 转置 卷积层背后的基本原理 here。本质上,Conv3DTranspose
层与普通卷积的方向相反——它允许将低分辨率图像上采样为高分辨率图像。
我建立了一个模型,该模型将一个时间序列的 3 个图像以及 5 个数字信息作为输入,并生成该时间序列的下三个图像。 我通过以下方式完成了这项工作:
- 构建用于处理图像的 ConvLSTM2D 模型(与 Keras 文档 here 中列出的示例非常相似)。输入大小=(3x128x128x3)
- 为具有几个密集层的表格数据构建一个简单模型。输入大小=(1,5)
- 连接这两个模型
- 有一个 Conv3D 模型可以生成接下来的 3 个图像
LSTM 模型产生大小为 393216 (3x128x128x8) 的输出。现在我必须将表格模型的输出设置为 49,152,以便我可以在下一层中获得 442368 (3x128x128x9) 的输入大小。因此,这种不必要的 inflation 表格模型的密集层使得原本高效的 LSTM 模型表现得很糟糕。
有没有更好的方法来连接这两个模型?有没有一种方法可以在表格模型的密集层中输出 10?
型号:
x_input = Input(shape=(None, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
x = BatchNormalization()(x)
x = Flatten()(x)
# x = MaxPooling3D()(x)
x_tab_input = Input(shape=(5))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(49152, activation="relu")(x_tab)
x_tab = Flatten()(x_tab)
concat = Concatenate()([x, x_tab])
output = Reshape((3,128,128,9))(concat)
output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)
model = Model([x_input, x_tab_input], output)
model.compile(loss='mae', optimizer='rmsprop')
模型总结:
Model: "functional_3"
______________________________________________________________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
======================================================================================================================================================
input_4 (InputLayer) [(None, None, 128, 128, 3)] 0
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_9 (ConvLSTM2D) (None, None, 128, 128, 32) 40448 input_4[0][0]
______________________________________________________________________________________________________________________________________________________
batch_normalization_9 (BatchNormalization) (None, None, 128, 128, 32) 128 conv_lst_m2d_9[0][0]
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_10 (ConvLSTM2D) (None, None, 128, 128, 16) 27712 batch_normalization_9[0][0]
______________________________________________________________________________________________________________________________________________________
batch_normalization_10 (BatchNormalization) (None, None, 128, 128, 16) 64 conv_lst_m2d_10[0][0]
______________________________________________________________________________________________________________________________________________________
input_5 (InputLayer) [(None, 5)] 0
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_11 (ConvLSTM2D) (None, None, 128, 128, 8) 6944 batch_normalization_10[0][0]
______________________________________________________________________________________________________________________________________________________
dense (Dense) (None, 100) 600 input_5[0][0]
______________________________________________________________________________________________________________________________________________________
batch_normalization_11 (BatchNormalization) (None, None, 128, 128, 8) 32 conv_lst_m2d_11[0][0]
______________________________________________________________________________________________________________________________________________________
dense_1 (Dense) (None, 49152) 4964352 dense[0][0]
______________________________________________________________________________________________________________________________________________________
flatten_3 (Flatten) (None, None) 0 batch_normalization_11[0][0]
______________________________________________________________________________________________________________________________________________________
flatten_4 (Flatten) (None, 49152) 0 dense_1[0][0]
______________________________________________________________________________________________________________________________________________________
concatenate (Concatenate) (None, None) 0 flatten_3[0][0]
flatten_4[0][0]
______________________________________________________________________________________________________________________________________________________
reshape_2 (Reshape) (None, 3, 128, 128, 9) 0 concatenate[0][0]
______________________________________________________________________________________________________________________________________________________
conv3d_2 (Conv3D) (None, 3, 128, 128, 3) 732 reshape_2[0][0]
======================================================================================================================================================
Total params: 5,041,012
Trainable params: 5,040,900
Non-trainable params: 112
______________________________________________________________________________________________________________________________________________________
我同意你的看法,巨大的 Dense
层(具有数百万个参数)可能会阻碍模型的性能。您可以选择以下两种方法之一,而不是 膨胀 带有 Dense
层的表格数据。
选项 1: 平铺 x_tab
张量,使其与您想要的形状相匹配。这可以通过以下步骤实现:
首先,不需要展平ConvLSTM2D
的编码张量:
x_input = Input(shape=(3, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
x = BatchNormalization()(x) # Shape=(None, None, 128, 128, 8)
# Commented: x = Flatten()(x)
其次,您可以使用一层或多层 Dense
处理表格数据。例如:
dim = 10
x_tab_input = Input(shape=(5))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(dim, activation="relu")(x_tab)
# x_tab = Flatten()(x_tab) # Note: Flattening a 2D tensor leaves the tensor unchanged
第三,我们包装张量流操作 tf.tile in a Lambda 层,有效地创建张量的副本 x_tab
以便它匹配所需的形状:
def repeat_tabular(x_tab):
h = x_tab[:, None, None, None, :] # Shape=(bs, 1, 1, 1, dim)
h = tf.tile(h, [1, 3, 128, 128, 1]) # Shape=(bs, 3, 128, 128, dim)
return h
x_tab = Lambda(repeat_tabular)(x_tab)
最后,我们沿着最后一个轴连接 x
和平铺的 x_tab
张量(您也可以考虑沿着第一个轴连接,对应于通道的维度)
concat = Concatenate(axis=-1)([x, x_tab]) # Shape=(3,128,128,8+dim)
output = concat
output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)
# ...
请注意,从某种意义上说,该解决方案可能有点天真,因为模型没有将图像的输入序列编码为低维表示,从而限制了网络的接受域并可能导致性能下降。
选项 2: 与自动编码器和 U-Net 类似,可能需要将图像序列编码为低维表示,以便丢弃不需要的变化(例如噪声),同时保留有意义的信号(例如需要推断序列的下 3 个图像)。这可以通过以下方式实现:
首先,将输入的图像序列编码为低维二维张量。例如,类似以下内容的内容:
x_input = Input(shape=(None, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2, return_sequences=False)(x)
x = BatchNormalization()(x)
x = Flatten()(x)
x = Dense(64, activation='relu')(x)
请注意,最后一个 ConvLSTM2D
没有返回序列。您可能想探索不同的编码器来达到这一点(例如,您也可以在此处使用池化层)。
其次,使用 Dense
层处理表格数据。例如:
dim = 10
x_tab_input = Input(shape=(5))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(dim, activation="relu")(x_tab)
第三,连接前两个流的数据:
concat = Concatenate(axis=-1)([x, x_tab])
第四,使用 Dense
+ Reshape
层将连接的向量投影到一系列低分辨率图像中:
h = Dense(3 * 32 * 32 * 3)(concat)
output = Reshape((3, 32, 32, 3))(h)
output
的形状允许将图像上采样为 (128, 128, 3)
的形状,但它在其他方面是任意的(例如,您可能还想在这里进行实验)。
最后,应用一层或多层 Conv3DTranspose 层以获得所需的输出(例如 3 张形状为 (128, 128, 3)
的图像)。
output = tf.keras.layers.Conv3DTranspose(filters=50, kernel_size=(3, 3, 3),
strides=(1, 2, 2), padding='same',
activation='relu')(output)
output = tf.keras.layers.Conv3DTranspose(filters=3, kernel_size=(3, 3, 3),
strides=(1, 2, 2), padding='same',
activation='relu')(output) # Shape=(None, 3, 128, 128, 3)
讨论 转置 卷积层背后的基本原理 here。本质上,Conv3DTranspose
层与普通卷积的方向相反——它允许将低分辨率图像上采样为高分辨率图像。