使用 conv1D 时输入数据和训练数据之间的维度不匹配

Dimension mismatch between input data and trained data when using conv1D

我曾尝试使用 Conv1D 构建我的第一个 CNN,因为我处理时间序列数据。我的目标是对 1501 形状的 input_data 进行压缩。 x_train 形状是 (550, 1501),我增加了它的尺寸以适应模型。

然而,编译器抱怨:

ValueError: A target array with shape (550, 1501, 1) was passed for an output of shape (None, 1500, 1) while using as loss mean_squared_error. This loss expects targets to have the same shape as the output.

这是代码

import numpy as np
from tensorflow.keras.layers import Input,Dense, Conv1D, MaxPooling1D, UpSampling1D, Flatten, Input
from tensorflow.keras import optimizers, Model
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K

#(1,128,1)
input_data = Input(shape=(1501,1))
fil_ord = 3
# Eecode
encode  = Conv1D(2000,  fil_ord, activation='relu', padding='same')input_data) 
encode = MaxPooling1D( 2 )(encode)
encode = Conv1D(750,   fil_ord, activation='relu', padding='same')(encode)

# Decode

decode  = Conv1D(750,  fil_ord, activation='relu', padding='same')(encode)
decode = UpSampling1D( 2)(decode)
decode = Conv1D(1,   fil_ord, activation='sigmoid', padding='same')(decode)


model = Model(input_data, decode)


model.summary()

from numpy import zeros, newaxis
x_train1=x_train[:,:,None]

batch_size = 128
epochs = 10
# Optimizer
sgd = optimizers.Adam(lr=0.001)

# compile
model.compile(loss='mse', optimizer=sgd)
# train
history = model.fit(x_train1, x_train1, batch_size=batch_size, epochs=epochs, verbose=2,shuffle=True)

model.summary() 输出:

错误出在 axis=1decode 输出维度,即 15001501 的目标 x_train1 维度不同。

这是由于这条链 max-poolingupsampling 操作: 1501 -> 750 -> 1500 其中 MaxPooling1D 在下采样时忽略一个附加元素,并在 axis=1 处输出维度 750 的特征,这些特征不会从 UpSampling1D 的上采样操作中恢复.

因此,目标 (x_train1) 和预测 (decode) 输出的形状不同,因此我们无法计算损失。

可用于解决此问题的两种方法是:

  • 裁剪 axis=1 中的目标(x_train 的)维度以匹配 decode 的维度,即 1500。这是执行此操作的一种方法: history = model.fit(x_train1, x_train1[:,:-1,], batch_size=batch_size, ...)
  • Paddecode 获得的输出与(比如)0's,以匹配 x_train 的维度,即 1501。一种方法是在 decode 上使用 ZeroPadding2D 层: ZeroPadding2D(padding=((0,0),(0,1),(0,0)))(decode)