在 keras 中训练具有多个输入 3D 数组的 CNN

Training a CNN with multiple input 3D-arrays in keras

我需要用 42 次 CT 扫描的 (128x128x128) 块训练一个 3D_Unet 模型。

对于 CT 扫描和面罩,我的输入数据是 128x128x128。 我将数组的形状扩展为 (128, 128, 128, 1)。其中 1 是频道。

问题是如何将我的 40 个 4D 数组列表提供给模型?

如何将 model.fit() 或 model.train_on_batch 与我的模型中指定的正确输入形状一起使用?

project_name = '3D-Unet Segmentation of Lungs'
img_rows = 128
img_cols = 128
img_depth = 128
# smooth = 1
K.set_image_data_format('channels_last') 
#corresponds to inputs with shape:
#(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)

def get_unet():
    inputs = Input(shape=(img_depth, img_rows, img_cols, 1))
    conv1 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling3D(pool_size=(2, 2, 2))(conv1)

    conv2 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling3D(pool_size=(2, 2, 2))(conv2)

     ....


model = Model(inputs=[inputs], outputs=[conv10])

model.summary()
#plot_model(model, to_file='model.png')

model.compile(optimizer=Adam(lr=1e-5, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.000000199), 
              loss='binary_crossentropy', metrics=['accuracy'])
return model

用于数组列表作为输入

我应该在 .train_on_batch() 或 .fit() 中指定什么?

这是我在使用 .train_on_batch 选项时遇到的错误:

ValueError:检查模型输入时出错:您传递给模型的 Numpy 数组列表不是模型预期的大小。预期会看到 1 个数组,但却得到了以下 42 个数组的列表

model.train_on_batch(train_arrays_list, mask_arrays_list)

这是我在使用 .model.fit 选项增加 axis=0 的数组形状后出现的错误。

UnboundLocalError: 局部变量 'batch_index' 在赋值前被引用

model.fit(train_arrays_list[0], mask_arrays_list[0], 
          batch_size=1, 
          epochs=50, 
          verbose=1, 
          shuffle=True, 
          validation_split=0.10, 
          callbacks=[model_checkpoint, csv_logger])

您必须将形状为 (128, 128, 128, 1) 的 numpy 数组列表转换为形状为 (42, 128, 128, 128, 1) 的堆叠 5 维 numpy 数组。你可以这样做:model.fit(np.array(train_arrays_list), np.array(mask_arrays_list), batch_size=1, ...)