使用 tensorflow Conv1D: 如何解决错误 "Input 0 of layer "conv1d_9" is incompatible with the layer: "?

Using tensorflow Conv1D: how can I solve error "Input 0 of layer "conv1d_9" is incompatible with the layer: "?

我正在使用TensorFlow对我模拟的超声波信号进行二进制分类,我想使用CNN。我是编程和机器学习的新手,所以我不知道我使用的术语是否正确,请耐心等待。 数据被组织成一个名为 'sig_data' 的数组,其中列是时间步长,行是不同的信号样本。这些值是信号的幅度。标签位于另一个名为 'sig_id' 的一维数组中,其中包含值 1 和 0。数据的形状如下:

data shape: (1000, 1000)
label shape: 1000

我已将数据放入 TF 数据集并分为训练集、验证集和测试集:

data_ds = tf.data.Dataset.from_tensors((sig_data, sig_id))

train_ds = data_ds.take(700)
val_ds = data_ds.skip(700).take(200)
test_ds = data_ds.skip(900).take(100)

train_ds = train_ds.shuffle(shuffle_buffer_size).batch(batch)
val_ds = val_ds.shuffle(shuffle_buffer_size).batch(batch)
test_ds = test_ds.batch(batch)

我创建的模型是:

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1000,1)),
    tf.keras.layers.Conv1D(50, 3, activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Dense(1, activation='sigmoid')
    ])

model.compile(
  optimizer='adam',
  loss='binary_crossentropy',
  metrics=['accuracy'])

history = model.fit(
  train_ds,
  validation_data=val_ds,
  batch_size=batch,
  epochs=25)

我收到以下错误:

ValueError: Exception encountered when calling layer "sequential_3" (type Sequential).
    
    Input 0 of layer "conv1d_3" is incompatible with the layer: expected axis -1 of input shape to have value 1, but received input with shape (None, 1000, 1000)

我已经查过这个问题并试图解决它。我认为问题出在输入形状上,所以我尝试按如下方式重塑我的数组:

sig_data_reshaped = np.expand_dims(sig_data, axis=-1)
sig_id_reshaped = np.expand_dims(sig_id, axis=-1)

reshaped data shape: (1000, 1000, 1)
reshaped label shape: (1000, 1)

但是当我运行我的代码时我仍然得到一个错误,

Input 0 of layer "conv1d_8" is incompatible with the layer: expected axis -1 of input shape to have value 1, but received input with shape (None, 1000, 1000)

我的错误是由于我组织数据集的方式造成的吗?为什么当我将数组重新整形为 3D 时,它仍然报错?

数据集 data_ds 包含一个形状为 (1000, 1000) 的记录。您可以尝试使用 from_tensor_slices 创建它。

data_ds = tf.data.Dataset.from_tensor_slices((sig_data, sig_id))