我未能训练 CNN + LSTM 模型。我怎么解决这个问题?数据集有问题吗?或模型? (Python 3.8 倍)

I failed to train CNN + LSTM model. How can I solve this problem? Is it have problem in dataset? or model? (Python 3.8x)

0。我用过:

1。我的问题

我尝试训练 CNN + LSTM Python 模型进行视频分类(二进制分类)。

但是...我未能训练我的模型。 我的 JupyterLab(>=3.0) 只打印了 Epoch 1/100 并且几乎停止了,或者重新启动了内核(我建议可能内存不足,但我的桌面有 16GB RAM!)。

我做错模型了吗?还是我的数据集有问题?

另外,有时我会减少训练数据的大小。(2000 -> 100) 但问题并没有解决。

这是我的模型和数据集的结构。

2。输入数据形状(我的数据集)

数据:data_training_ar

它有2697个视频的160*160大小的RGB ndarray。每个视频有30帧。

array([[[[0.03105 , 0.02397 , 0.02713 ],
         [0.08167 , 0.0738  , 0.0777  ],
         [0.1142  , 0.1064  , 0.1103  ],
         ...,
         [0.183   , 0.1752  , 0.1713  ],
         [0.12427 , 0.11646 , 0.1137  ],
         [0.01765 , 0.0098  , 0.00784 ]],

        [[0.1113  , 0.1051  , 0.1074  ],
         [0.5225  , 0.5146  , 0.5186  ],
         [0.3794  , 0.3713  , 0.3755  ],
         ...,
         [0.2229  , 0.2151  , 0.2112  ],
         [0.1255  , 0.1177  , 0.1137  ],
         [0.013725, 0.00816 , 0.005882]],

        [[0.124   , 0.11615 , 0.1201  ],
         [0.4556  , 0.4478  , 0.4517  ],
         [0.3982  , 0.3904  , 0.3943  ],
         ...,
         [0.1613  , 0.1534  , 0.1495  ],
         [0.1173  , 0.10956 , 0.1075  ],
         [0.0098  , 0.005882, 0.005882]],

        ...,

        [[0.08453 , 0.08246 , 0.08246 ],
         [0.4902  , 0.498   , 0.4863  ],
         [0.5337  , 0.5728  , 0.5337  ],
         ...,
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ]],

        [[0.08234 , 0.0807  , 0.08466 ],
         [0.482   , 0.4941  , 0.4883  ],
         [0.51    , 0.554   , 0.521   ],
         ...,
         [0.9766  , 0.9766  , 0.9766  ],
         [0.9663  , 0.9663  , 0.9663  ],
         [0.9683  , 0.9683  , 0.9683  ]],

        [[0.08234 , 0.0843  , 0.0863  ],
         [0.4824  , 0.496   , 0.4902  ],
         [0.51    , 0.551   , 0.5195  ],
         ...,
         [0.4133  , 0.4133  , 0.4133  ],
         [0.3955  , 0.3955  , 0.3955  ],
         [0.3523  , 0.3523  , 0.3523  ]]],


       [[[0.01689 , 0.01221 , 0.01296 ],
         [0.0955  , 0.08765 , 0.09155 ],
         [0.1139  , 0.1061  , 0.11    ],
         ...,
         [0.179   , 0.1711  , 0.1672  ],
         [0.12354 , 0.11566 , 0.11255 ],
         [0.01645 , 0.0098  , 0.0098  ]],

        [[0.11365 , 0.10583 , 0.10974 ],
         [0.5186  , 0.5107  , 0.5146  ],
         [0.3809  , 0.373   , 0.377   ],
         ...,
         [0.232   , 0.2242  , 0.2203  ],
         [0.1232  , 0.11566 , 0.11176 ],
         [0.013725, 0.0098  , 0.00784 ]],

        [[0.135   , 0.1274  , 0.1311  ],
         [0.4604  , 0.4526  , 0.4565  ],
         [0.3862  , 0.3784  , 0.3823  ],
         ...,
         [0.1727  , 0.1648  , 0.1609  ],
         [0.11115 , 0.10333 , 0.09937 ],
         [0.013725, 0.00784 , 0.005882]],

        ...,

        [[0.07855 , 0.0787  , 0.0745  ],
         [0.4788  , 0.4963  , 0.4785  ],
         [0.5337  , 0.563   , 0.5317  ],
         ...,
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ]],

        [[0.0745  , 0.0804  , 0.0784  ],
         [0.4727  , 0.496   , 0.4805  ],
         [0.5137  , 0.551   , 0.5254  ],
         ...,
         [0.9717  , 0.9717  , 0.9717  ],
         [0.974   , 0.974   , 0.974   ],
         [0.973   , 0.973   , 0.973   ]],

        [[0.0745  , 0.08234 , 0.0804  ],
         [0.4727  , 0.498   , 0.4844  ],
         [0.5137  , 0.551   , 0.5254  ],
         ...,
         [0.4067  , 0.4067  , 0.4067  ],
         [0.3923  , 0.3923  , 0.3923  ],
         [0.3586  , 0.3586  , 0.3586  ]]],


       [[[0.01689 , 0.01025 , 0.01296 ],
         [0.09265 , 0.07965 , 0.0836  ],
         [0.12445 , 0.1053  , 0.11    ],
         ...,
         [0.172   , 0.1674  , 0.1635  ],
         [0.111   , 0.1149  , 0.10706 ],
         [0.00784 , 0.008606, 0.00784 ]],

        [[0.1068  , 0.0996  , 0.1029  ],
         [0.522   , 0.5117  , 0.5156  ],
         [0.3933  , 0.3755  , 0.3813  ],
         ...,
         [0.2363  , 0.2305  , 0.2249  ],
         [0.1209  , 0.1213  , 0.1134  ],
         [0.00784 , 0.00948 , 0.00784 ]],

        [[0.1294  , 0.1239  , 0.1257  ],
         [0.4658  , 0.4563  , 0.46    ],
         [0.395   , 0.3796  , 0.3835  ],
         ...,
         [0.1705  , 0.1627  , 0.1588  ],
         [0.1207  , 0.11676 , 0.111   ],
         [0.00968 , 0.00968 , 0.005882]],

        ...,

        [[0.0726  , 0.0784  , 0.0727  ],
         [0.471   , 0.4963  , 0.4749  ],
         [0.528   , 0.565   , 0.5317  ],
         ...,
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ],
         [0.9883  , 0.9883  , 0.9883  ]],

        [[0.0745  , 0.0804  , 0.0784  ],
         [0.4727  , 0.496   , 0.4805  ],
         [0.517   , 0.5547  , 0.5293  ],
         ...,
         [0.977   , 0.977   , 0.977   ],
         [0.9707  , 0.9707  , 0.9707  ],
         [0.9766  , 0.9766  , 0.9766  ]],

        [[0.0745  , 0.08234 , 0.08234 ],
         [0.4746  , 0.498   , 0.4844  ],
         [0.5137  , 0.5527  , 0.5254  ],
         ...,
         [0.4087  , 0.4087  , 0.4087  ],
         [0.3977  , 0.3977  , 0.3977  ],
         [0.3484  , 0.3484  , 0.3484  ]]],


       ...,


       [[[0.01778 , 0.01778 , 0.01778 ],
         [0.08307 , 0.08307 , 0.08307 ],
         [0.1046  , 0.1046  , 0.1046  ],
         ...,
         [0.1659  , 0.1744  , 0.1631  ],
         [0.08594 , 0.0938  , 0.0899  ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.0795  , 0.0795  , 0.0795  ],
         [0.4434  , 0.4434  , 0.4434  ],
         [0.3796  , 0.3796  , 0.3796  ],
         ...,
         [0.2612  , 0.2708  , 0.2573  ],
         [0.1079  , 0.1157  , 0.11017 ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.0664  , 0.0664  , 0.0664  ],
         [0.3572  , 0.3572  , 0.3572  ],
         [0.388   , 0.388   , 0.388   ],
         ...,
         [0.1753  , 0.1792  , 0.1674  ],
         [0.1013  , 0.1054  , 0.10144 ],
         [0.00772 , 0.01164 , 0.01152 ]],

        ...,

        [[0.08234 , 0.0784  , 0.0844  ],
         [0.512   , 0.512   , 0.516   ],
         [0.563   , 0.563   , 0.557   ],
         ...,
         [0.9844  , 0.9844  , 0.9844  ],
         [0.9883  , 0.988   , 0.9883  ],
         [0.9883  , 0.988   , 0.9883  ]],

        [[0.0843  , 0.08264 , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.549   , 0.549   , 0.549   ],
         ...,
         [0.972   , 0.972   , 0.972   ],
         [0.9766  , 0.9766  , 0.9766  ],
         [0.9727  , 0.9727  , 0.9727  ]],

        [[0.0843  , 0.0843  , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.547   , 0.547   , 0.547   ],
         ...,
         [0.4138  , 0.4138  , 0.4138  ],
         [0.3975  , 0.3975  , 0.3975  ],
         [0.3496  , 0.3496  , 0.3496  ]]],


       [[[0.01581 , 0.01581 , 0.01581 ],
         [0.0835  , 0.0835  , 0.0835  ],
         [0.1042  , 0.1042  , 0.1042  ],
         ...,
         [0.1631  , 0.1725  , 0.1611  ],
         [0.08594 , 0.0938  , 0.0899  ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.07623 , 0.07623 , 0.07623 ],
         [0.442   , 0.442   , 0.442   ],
         [0.3748  , 0.3748  , 0.3748  ],
         ...,
         [0.2605  , 0.269   , 0.257   ],
         [0.1082  , 0.116   , 0.1118  ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.0646  , 0.0646  , 0.0646  ],
         [0.3538  , 0.3538  , 0.3538  ],
         [0.3918  , 0.3918  , 0.3918  ],
         ...,
         [0.1735  , 0.1792  , 0.1655  ],
         [0.1013  , 0.10724 , 0.10333 ],
         [0.00772 , 0.01164 , 0.01152 ]],

        ...,

        [[0.08234 , 0.0784  , 0.0844  ],
         [0.512   , 0.512   , 0.516   ],
         [0.563   , 0.563   , 0.557   ],
         ...,
         [0.9844  , 0.9844  , 0.9844  ],
         [0.9883  , 0.988   , 0.9883  ],
         [0.9883  , 0.988   , 0.9883  ]],

        [[0.0843  , 0.08264 , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.549   , 0.549   , 0.549   ],
         ...,
         [0.972   , 0.972   , 0.972   ],
         [0.9766  , 0.9766  , 0.9766  ],
         [0.9727  , 0.9727  , 0.9727  ]],

        [[0.0843  , 0.0843  , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.547   , 0.547   , 0.547   ],
         ...,
         [0.4138  , 0.4138  , 0.4138  ],
         [0.3975  , 0.3975  , 0.3975  ],
         [0.3496  , 0.3496  , 0.3496  ]]],


       [[[0.01581 , 0.01581 , 0.01581 ],
         [0.0835  , 0.0835  , 0.0835  ],
         [0.1042  , 0.1042  , 0.1042  ],
         ...,
         [0.1624  , 0.1709  , 0.1592  ],
         [0.08594 , 0.0938  , 0.0899  ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.07623 , 0.07623 , 0.07623 ],
         [0.442   , 0.442   , 0.442   ],
         [0.3748  , 0.3748  , 0.3748  ],
         ...,
         [0.2646  , 0.2747  , 0.261   ],
         [0.1082  , 0.116   , 0.11017 ],
         [0.00784 , 0.011765, 0.011765]],

        [[0.0646  , 0.0646  , 0.0646  ],
         [0.3538  , 0.3538  , 0.3538  ],
         [0.3918  , 0.3918  , 0.3918  ],
         ...,
         [0.1755  , 0.1792  , 0.1674  ],
         [0.1013  , 0.1054  , 0.10144 ],
         [0.00772 , 0.01164 , 0.01152 ]],

        ...,

        [[0.08234 , 0.0784  , 0.0844  ],
         [0.512   , 0.512   , 0.516   ],
         [0.563   , 0.563   , 0.557   ],
         ...,
         [0.9844  , 0.9844  , 0.9844  ],
         [0.9883  , 0.988   , 0.9883  ],
         [0.9883  , 0.988   , 0.9883  ]],

        [[0.0843  , 0.08264 , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.549   , 0.549   , 0.549   ],
         ...,
         [0.972   , 0.972   , 0.972   ],
         [0.9766  , 0.9766  , 0.9766  ],
         [0.9727  , 0.9727  , 0.9727  ]],

        [[0.0843  , 0.0843  , 0.0843  ],
         [0.506   , 0.506   , 0.506   ],
         [0.547   , 0.547   , 0.547   ],
         ...,
         [0.4138  , 0.4138  , 0.4138  ],
         [0.3975  , 0.3975  , 0.3975  ],
         [0.3496  , 0.3496  , 0.3496  ]]]], dtype=float16)

目标:label_training_ar

array([[0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.]])

3。 VGG19 + LSTM 模型

3-1。代码

base_model=keras.applications.VGG19(include_top=False, input_shape=(160, 160, 3), weights='imagenet')

image_model=keras.models.Sequential()
image_model.add(base_model)
image_model.add(keras.layers.Flatten())
image_model.add(keras.layers.Dense(4096, activation='relu', name='fc1'))
image_model.add(keras.layers.Dense(4096, activation='relu', name='fc2'))
image_model.add(keras.layers.Dense(1000, activation='softmax', name='predictions'))

chunk_size=4096
n_chunks=30
rnn_size=512

model=keras.models.Sequential()
model.add(keras.layers.TimeDistributed(image_model, input_shape=(30, 160, 160, 3)))

model.add(keras.layers.LSTM(rnn_size, input_shape=(n_chunks, chunk_size))) # (30, 4096)
model.add(keras.layers.Dense(1024))
model.add(keras.layers.Activation('relu'))
model.add(keras.layers.Dense(256))
model.add(keras.layers.Activation('sigmoid'))
model.add(keras.layers.Dense(2))
model.add(keras.layers.Activation('softmax'))

model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])

3-2。绘图图像

4。模型拟合(模型训练)

epoch=100
batchS=30
history=model.fit(x=data_training_ar[0:2000], y=label_training_ar[0:2000], epochs=epoch,
                  validation_data=(data_training_ar[2000:], label_training_ar[2000:]),
                  callbacks=[checkpoint_cb], #keras.callbacks.ModelCheckpoint('210429_vc_13-02_checkpoint.h5', save_best_only=True)
                  batch_size=batchS, verbose=2)

尝试直接在命令行上使用 Spyder 或记事本和 运行 您的脚本。这是为了确保您的问题与 Web 服务器 运行ning Jupyter 的某些超时无关。它还将允许您查看完整的堆栈跟踪。

如果可能,请尝试使用 PyCharm 并查看错误是否仍然存在?还要检查它是否是相同的错误。

我 运行 VGG 系列模型在 Google Colab 中。它相当快。