我未能训练 CNN + LSTM 模型。我怎么解决这个问题?数据集有问题吗?或模型? (Python 3.8 倍)
I failed to train CNN + LSTM model. How can I solve this problem? Is it have problem in dataset? or model? (Python 3.8x)
0。我用过:
Python 3.8x
JupyterLab >=3.0
张量流
凯拉斯
VGG19(预训练模型)
1。我的问题
我尝试训练 CNN + LSTM Python 模型进行视频分类(二进制分类)。
但是...我未能训练我的模型。 我的 JupyterLab(>=3.0) 只打印了 Epoch 1/100
并且几乎停止了,或者重新启动了内核(我建议可能内存不足,但我的桌面有 16GB RAM!)。
我做错模型了吗?还是我的数据集有问题?
另外,有时我会减少训练数据的大小。(2000 -> 100) 但问题并没有解决。
这是我的模型和数据集的结构。
2。输入数据形状(我的数据集)
数据:data_training_ar
- 类型:numpy 数组
- 形状:(2697, 30, 160, 160, 3)
它有2697个视频的160*160大小的RGB ndarray。每个视频有30帧。
- 示例:data_training_ar[10]
array([[[[0.03105 , 0.02397 , 0.02713 ],
[0.08167 , 0.0738 , 0.0777 ],
[0.1142 , 0.1064 , 0.1103 ],
...,
[0.183 , 0.1752 , 0.1713 ],
[0.12427 , 0.11646 , 0.1137 ],
[0.01765 , 0.0098 , 0.00784 ]],
[[0.1113 , 0.1051 , 0.1074 ],
[0.5225 , 0.5146 , 0.5186 ],
[0.3794 , 0.3713 , 0.3755 ],
...,
[0.2229 , 0.2151 , 0.2112 ],
[0.1255 , 0.1177 , 0.1137 ],
[0.013725, 0.00816 , 0.005882]],
[[0.124 , 0.11615 , 0.1201 ],
[0.4556 , 0.4478 , 0.4517 ],
[0.3982 , 0.3904 , 0.3943 ],
...,
[0.1613 , 0.1534 , 0.1495 ],
[0.1173 , 0.10956 , 0.1075 ],
[0.0098 , 0.005882, 0.005882]],
...,
[[0.08453 , 0.08246 , 0.08246 ],
[0.4902 , 0.498 , 0.4863 ],
[0.5337 , 0.5728 , 0.5337 ],
...,
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ]],
[[0.08234 , 0.0807 , 0.08466 ],
[0.482 , 0.4941 , 0.4883 ],
[0.51 , 0.554 , 0.521 ],
...,
[0.9766 , 0.9766 , 0.9766 ],
[0.9663 , 0.9663 , 0.9663 ],
[0.9683 , 0.9683 , 0.9683 ]],
[[0.08234 , 0.0843 , 0.0863 ],
[0.4824 , 0.496 , 0.4902 ],
[0.51 , 0.551 , 0.5195 ],
...,
[0.4133 , 0.4133 , 0.4133 ],
[0.3955 , 0.3955 , 0.3955 ],
[0.3523 , 0.3523 , 0.3523 ]]],
[[[0.01689 , 0.01221 , 0.01296 ],
[0.0955 , 0.08765 , 0.09155 ],
[0.1139 , 0.1061 , 0.11 ],
...,
[0.179 , 0.1711 , 0.1672 ],
[0.12354 , 0.11566 , 0.11255 ],
[0.01645 , 0.0098 , 0.0098 ]],
[[0.11365 , 0.10583 , 0.10974 ],
[0.5186 , 0.5107 , 0.5146 ],
[0.3809 , 0.373 , 0.377 ],
...,
[0.232 , 0.2242 , 0.2203 ],
[0.1232 , 0.11566 , 0.11176 ],
[0.013725, 0.0098 , 0.00784 ]],
[[0.135 , 0.1274 , 0.1311 ],
[0.4604 , 0.4526 , 0.4565 ],
[0.3862 , 0.3784 , 0.3823 ],
...,
[0.1727 , 0.1648 , 0.1609 ],
[0.11115 , 0.10333 , 0.09937 ],
[0.013725, 0.00784 , 0.005882]],
...,
[[0.07855 , 0.0787 , 0.0745 ],
[0.4788 , 0.4963 , 0.4785 ],
[0.5337 , 0.563 , 0.5317 ],
...,
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ]],
[[0.0745 , 0.0804 , 0.0784 ],
[0.4727 , 0.496 , 0.4805 ],
[0.5137 , 0.551 , 0.5254 ],
...,
[0.9717 , 0.9717 , 0.9717 ],
[0.974 , 0.974 , 0.974 ],
[0.973 , 0.973 , 0.973 ]],
[[0.0745 , 0.08234 , 0.0804 ],
[0.4727 , 0.498 , 0.4844 ],
[0.5137 , 0.551 , 0.5254 ],
...,
[0.4067 , 0.4067 , 0.4067 ],
[0.3923 , 0.3923 , 0.3923 ],
[0.3586 , 0.3586 , 0.3586 ]]],
[[[0.01689 , 0.01025 , 0.01296 ],
[0.09265 , 0.07965 , 0.0836 ],
[0.12445 , 0.1053 , 0.11 ],
...,
[0.172 , 0.1674 , 0.1635 ],
[0.111 , 0.1149 , 0.10706 ],
[0.00784 , 0.008606, 0.00784 ]],
[[0.1068 , 0.0996 , 0.1029 ],
[0.522 , 0.5117 , 0.5156 ],
[0.3933 , 0.3755 , 0.3813 ],
...,
[0.2363 , 0.2305 , 0.2249 ],
[0.1209 , 0.1213 , 0.1134 ],
[0.00784 , 0.00948 , 0.00784 ]],
[[0.1294 , 0.1239 , 0.1257 ],
[0.4658 , 0.4563 , 0.46 ],
[0.395 , 0.3796 , 0.3835 ],
...,
[0.1705 , 0.1627 , 0.1588 ],
[0.1207 , 0.11676 , 0.111 ],
[0.00968 , 0.00968 , 0.005882]],
...,
[[0.0726 , 0.0784 , 0.0727 ],
[0.471 , 0.4963 , 0.4749 ],
[0.528 , 0.565 , 0.5317 ],
...,
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ]],
[[0.0745 , 0.0804 , 0.0784 ],
[0.4727 , 0.496 , 0.4805 ],
[0.517 , 0.5547 , 0.5293 ],
...,
[0.977 , 0.977 , 0.977 ],
[0.9707 , 0.9707 , 0.9707 ],
[0.9766 , 0.9766 , 0.9766 ]],
[[0.0745 , 0.08234 , 0.08234 ],
[0.4746 , 0.498 , 0.4844 ],
[0.5137 , 0.5527 , 0.5254 ],
...,
[0.4087 , 0.4087 , 0.4087 ],
[0.3977 , 0.3977 , 0.3977 ],
[0.3484 , 0.3484 , 0.3484 ]]],
...,
[[[0.01778 , 0.01778 , 0.01778 ],
[0.08307 , 0.08307 , 0.08307 ],
[0.1046 , 0.1046 , 0.1046 ],
...,
[0.1659 , 0.1744 , 0.1631 ],
[0.08594 , 0.0938 , 0.0899 ],
[0.00784 , 0.011765, 0.011765]],
[[0.0795 , 0.0795 , 0.0795 ],
[0.4434 , 0.4434 , 0.4434 ],
[0.3796 , 0.3796 , 0.3796 ],
...,
[0.2612 , 0.2708 , 0.2573 ],
[0.1079 , 0.1157 , 0.11017 ],
[0.00784 , 0.011765, 0.011765]],
[[0.0664 , 0.0664 , 0.0664 ],
[0.3572 , 0.3572 , 0.3572 ],
[0.388 , 0.388 , 0.388 ],
...,
[0.1753 , 0.1792 , 0.1674 ],
[0.1013 , 0.1054 , 0.10144 ],
[0.00772 , 0.01164 , 0.01152 ]],
...,
[[0.08234 , 0.0784 , 0.0844 ],
[0.512 , 0.512 , 0.516 ],
[0.563 , 0.563 , 0.557 ],
...,
[0.9844 , 0.9844 , 0.9844 ],
[0.9883 , 0.988 , 0.9883 ],
[0.9883 , 0.988 , 0.9883 ]],
[[0.0843 , 0.08264 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.549 , 0.549 , 0.549 ],
...,
[0.972 , 0.972 , 0.972 ],
[0.9766 , 0.9766 , 0.9766 ],
[0.9727 , 0.9727 , 0.9727 ]],
[[0.0843 , 0.0843 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.547 , 0.547 , 0.547 ],
...,
[0.4138 , 0.4138 , 0.4138 ],
[0.3975 , 0.3975 , 0.3975 ],
[0.3496 , 0.3496 , 0.3496 ]]],
[[[0.01581 , 0.01581 , 0.01581 ],
[0.0835 , 0.0835 , 0.0835 ],
[0.1042 , 0.1042 , 0.1042 ],
...,
[0.1631 , 0.1725 , 0.1611 ],
[0.08594 , 0.0938 , 0.0899 ],
[0.00784 , 0.011765, 0.011765]],
[[0.07623 , 0.07623 , 0.07623 ],
[0.442 , 0.442 , 0.442 ],
[0.3748 , 0.3748 , 0.3748 ],
...,
[0.2605 , 0.269 , 0.257 ],
[0.1082 , 0.116 , 0.1118 ],
[0.00784 , 0.011765, 0.011765]],
[[0.0646 , 0.0646 , 0.0646 ],
[0.3538 , 0.3538 , 0.3538 ],
[0.3918 , 0.3918 , 0.3918 ],
...,
[0.1735 , 0.1792 , 0.1655 ],
[0.1013 , 0.10724 , 0.10333 ],
[0.00772 , 0.01164 , 0.01152 ]],
...,
[[0.08234 , 0.0784 , 0.0844 ],
[0.512 , 0.512 , 0.516 ],
[0.563 , 0.563 , 0.557 ],
...,
[0.9844 , 0.9844 , 0.9844 ],
[0.9883 , 0.988 , 0.9883 ],
[0.9883 , 0.988 , 0.9883 ]],
[[0.0843 , 0.08264 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.549 , 0.549 , 0.549 ],
...,
[0.972 , 0.972 , 0.972 ],
[0.9766 , 0.9766 , 0.9766 ],
[0.9727 , 0.9727 , 0.9727 ]],
[[0.0843 , 0.0843 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.547 , 0.547 , 0.547 ],
...,
[0.4138 , 0.4138 , 0.4138 ],
[0.3975 , 0.3975 , 0.3975 ],
[0.3496 , 0.3496 , 0.3496 ]]],
[[[0.01581 , 0.01581 , 0.01581 ],
[0.0835 , 0.0835 , 0.0835 ],
[0.1042 , 0.1042 , 0.1042 ],
...,
[0.1624 , 0.1709 , 0.1592 ],
[0.08594 , 0.0938 , 0.0899 ],
[0.00784 , 0.011765, 0.011765]],
[[0.07623 , 0.07623 , 0.07623 ],
[0.442 , 0.442 , 0.442 ],
[0.3748 , 0.3748 , 0.3748 ],
...,
[0.2646 , 0.2747 , 0.261 ],
[0.1082 , 0.116 , 0.11017 ],
[0.00784 , 0.011765, 0.011765]],
[[0.0646 , 0.0646 , 0.0646 ],
[0.3538 , 0.3538 , 0.3538 ],
[0.3918 , 0.3918 , 0.3918 ],
...,
[0.1755 , 0.1792 , 0.1674 ],
[0.1013 , 0.1054 , 0.10144 ],
[0.00772 , 0.01164 , 0.01152 ]],
...,
[[0.08234 , 0.0784 , 0.0844 ],
[0.512 , 0.512 , 0.516 ],
[0.563 , 0.563 , 0.557 ],
...,
[0.9844 , 0.9844 , 0.9844 ],
[0.9883 , 0.988 , 0.9883 ],
[0.9883 , 0.988 , 0.9883 ]],
[[0.0843 , 0.08264 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.549 , 0.549 , 0.549 ],
...,
[0.972 , 0.972 , 0.972 ],
[0.9766 , 0.9766 , 0.9766 ],
[0.9727 , 0.9727 , 0.9727 ]],
[[0.0843 , 0.0843 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.547 , 0.547 , 0.547 ],
...,
[0.4138 , 0.4138 , 0.4138 ],
[0.3975 , 0.3975 , 0.3975 ],
[0.3496 , 0.3496 , 0.3496 ]]]], dtype=float16)
目标:label_training_ar
类型:numpy 数组
形状:(2697, 30, 2)
示例:label_training_ar[10]
array([[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.]])
3。 VGG19 + LSTM 模型
3-1。代码
base_model=keras.applications.VGG19(include_top=False, input_shape=(160, 160, 3), weights='imagenet')
image_model=keras.models.Sequential()
image_model.add(base_model)
image_model.add(keras.layers.Flatten())
image_model.add(keras.layers.Dense(4096, activation='relu', name='fc1'))
image_model.add(keras.layers.Dense(4096, activation='relu', name='fc2'))
image_model.add(keras.layers.Dense(1000, activation='softmax', name='predictions'))
chunk_size=4096
n_chunks=30
rnn_size=512
model=keras.models.Sequential()
model.add(keras.layers.TimeDistributed(image_model, input_shape=(30, 160, 160, 3)))
model.add(keras.layers.LSTM(rnn_size, input_shape=(n_chunks, chunk_size))) # (30, 4096)
model.add(keras.layers.Dense(1024))
model.add(keras.layers.Activation('relu'))
model.add(keras.layers.Dense(256))
model.add(keras.layers.Activation('sigmoid'))
model.add(keras.layers.Dense(2))
model.add(keras.layers.Activation('softmax'))
model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])
3-2。绘图图像
4。模型拟合(模型训练)
epoch=100
batchS=30
history=model.fit(x=data_training_ar[0:2000], y=label_training_ar[0:2000], epochs=epoch,
validation_data=(data_training_ar[2000:], label_training_ar[2000:]),
callbacks=[checkpoint_cb], #keras.callbacks.ModelCheckpoint('210429_vc_13-02_checkpoint.h5', save_best_only=True)
batch_size=batchS, verbose=2)
尝试直接在命令行上使用 Spyder 或记事本和 运行 您的脚本。这是为了确保您的问题与 Web 服务器 运行ning Jupyter 的某些超时无关。它还将允许您查看完整的堆栈跟踪。
如果可能,请尝试使用 PyCharm 并查看错误是否仍然存在?还要检查它是否是相同的错误。
我 运行 VGG 系列模型在 Google Colab 中。它相当快。
0。我用过:
Python 3.8x
JupyterLab >=3.0
张量流
凯拉斯
VGG19(预训练模型)
1。我的问题
我尝试训练 CNN + LSTM Python 模型进行视频分类(二进制分类)。
但是...我未能训练我的模型。 我的 JupyterLab(>=3.0) 只打印了 Epoch 1/100
并且几乎停止了,或者重新启动了内核(我建议可能内存不足,但我的桌面有 16GB RAM!)。
我做错模型了吗?还是我的数据集有问题?
另外,有时我会减少训练数据的大小。(2000 -> 100) 但问题并没有解决。
这是我的模型和数据集的结构。
2。输入数据形状(我的数据集)
数据:data_training_ar
- 类型:numpy 数组
- 形状:(2697, 30, 160, 160, 3)
它有2697个视频的160*160大小的RGB ndarray。每个视频有30帧。
- 示例:data_training_ar[10]
array([[[[0.03105 , 0.02397 , 0.02713 ],
[0.08167 , 0.0738 , 0.0777 ],
[0.1142 , 0.1064 , 0.1103 ],
...,
[0.183 , 0.1752 , 0.1713 ],
[0.12427 , 0.11646 , 0.1137 ],
[0.01765 , 0.0098 , 0.00784 ]],
[[0.1113 , 0.1051 , 0.1074 ],
[0.5225 , 0.5146 , 0.5186 ],
[0.3794 , 0.3713 , 0.3755 ],
...,
[0.2229 , 0.2151 , 0.2112 ],
[0.1255 , 0.1177 , 0.1137 ],
[0.013725, 0.00816 , 0.005882]],
[[0.124 , 0.11615 , 0.1201 ],
[0.4556 , 0.4478 , 0.4517 ],
[0.3982 , 0.3904 , 0.3943 ],
...,
[0.1613 , 0.1534 , 0.1495 ],
[0.1173 , 0.10956 , 0.1075 ],
[0.0098 , 0.005882, 0.005882]],
...,
[[0.08453 , 0.08246 , 0.08246 ],
[0.4902 , 0.498 , 0.4863 ],
[0.5337 , 0.5728 , 0.5337 ],
...,
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ]],
[[0.08234 , 0.0807 , 0.08466 ],
[0.482 , 0.4941 , 0.4883 ],
[0.51 , 0.554 , 0.521 ],
...,
[0.9766 , 0.9766 , 0.9766 ],
[0.9663 , 0.9663 , 0.9663 ],
[0.9683 , 0.9683 , 0.9683 ]],
[[0.08234 , 0.0843 , 0.0863 ],
[0.4824 , 0.496 , 0.4902 ],
[0.51 , 0.551 , 0.5195 ],
...,
[0.4133 , 0.4133 , 0.4133 ],
[0.3955 , 0.3955 , 0.3955 ],
[0.3523 , 0.3523 , 0.3523 ]]],
[[[0.01689 , 0.01221 , 0.01296 ],
[0.0955 , 0.08765 , 0.09155 ],
[0.1139 , 0.1061 , 0.11 ],
...,
[0.179 , 0.1711 , 0.1672 ],
[0.12354 , 0.11566 , 0.11255 ],
[0.01645 , 0.0098 , 0.0098 ]],
[[0.11365 , 0.10583 , 0.10974 ],
[0.5186 , 0.5107 , 0.5146 ],
[0.3809 , 0.373 , 0.377 ],
...,
[0.232 , 0.2242 , 0.2203 ],
[0.1232 , 0.11566 , 0.11176 ],
[0.013725, 0.0098 , 0.00784 ]],
[[0.135 , 0.1274 , 0.1311 ],
[0.4604 , 0.4526 , 0.4565 ],
[0.3862 , 0.3784 , 0.3823 ],
...,
[0.1727 , 0.1648 , 0.1609 ],
[0.11115 , 0.10333 , 0.09937 ],
[0.013725, 0.00784 , 0.005882]],
...,
[[0.07855 , 0.0787 , 0.0745 ],
[0.4788 , 0.4963 , 0.4785 ],
[0.5337 , 0.563 , 0.5317 ],
...,
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ]],
[[0.0745 , 0.0804 , 0.0784 ],
[0.4727 , 0.496 , 0.4805 ],
[0.5137 , 0.551 , 0.5254 ],
...,
[0.9717 , 0.9717 , 0.9717 ],
[0.974 , 0.974 , 0.974 ],
[0.973 , 0.973 , 0.973 ]],
[[0.0745 , 0.08234 , 0.0804 ],
[0.4727 , 0.498 , 0.4844 ],
[0.5137 , 0.551 , 0.5254 ],
...,
[0.4067 , 0.4067 , 0.4067 ],
[0.3923 , 0.3923 , 0.3923 ],
[0.3586 , 0.3586 , 0.3586 ]]],
[[[0.01689 , 0.01025 , 0.01296 ],
[0.09265 , 0.07965 , 0.0836 ],
[0.12445 , 0.1053 , 0.11 ],
...,
[0.172 , 0.1674 , 0.1635 ],
[0.111 , 0.1149 , 0.10706 ],
[0.00784 , 0.008606, 0.00784 ]],
[[0.1068 , 0.0996 , 0.1029 ],
[0.522 , 0.5117 , 0.5156 ],
[0.3933 , 0.3755 , 0.3813 ],
...,
[0.2363 , 0.2305 , 0.2249 ],
[0.1209 , 0.1213 , 0.1134 ],
[0.00784 , 0.00948 , 0.00784 ]],
[[0.1294 , 0.1239 , 0.1257 ],
[0.4658 , 0.4563 , 0.46 ],
[0.395 , 0.3796 , 0.3835 ],
...,
[0.1705 , 0.1627 , 0.1588 ],
[0.1207 , 0.11676 , 0.111 ],
[0.00968 , 0.00968 , 0.005882]],
...,
[[0.0726 , 0.0784 , 0.0727 ],
[0.471 , 0.4963 , 0.4749 ],
[0.528 , 0.565 , 0.5317 ],
...,
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ],
[0.9883 , 0.9883 , 0.9883 ]],
[[0.0745 , 0.0804 , 0.0784 ],
[0.4727 , 0.496 , 0.4805 ],
[0.517 , 0.5547 , 0.5293 ],
...,
[0.977 , 0.977 , 0.977 ],
[0.9707 , 0.9707 , 0.9707 ],
[0.9766 , 0.9766 , 0.9766 ]],
[[0.0745 , 0.08234 , 0.08234 ],
[0.4746 , 0.498 , 0.4844 ],
[0.5137 , 0.5527 , 0.5254 ],
...,
[0.4087 , 0.4087 , 0.4087 ],
[0.3977 , 0.3977 , 0.3977 ],
[0.3484 , 0.3484 , 0.3484 ]]],
...,
[[[0.01778 , 0.01778 , 0.01778 ],
[0.08307 , 0.08307 , 0.08307 ],
[0.1046 , 0.1046 , 0.1046 ],
...,
[0.1659 , 0.1744 , 0.1631 ],
[0.08594 , 0.0938 , 0.0899 ],
[0.00784 , 0.011765, 0.011765]],
[[0.0795 , 0.0795 , 0.0795 ],
[0.4434 , 0.4434 , 0.4434 ],
[0.3796 , 0.3796 , 0.3796 ],
...,
[0.2612 , 0.2708 , 0.2573 ],
[0.1079 , 0.1157 , 0.11017 ],
[0.00784 , 0.011765, 0.011765]],
[[0.0664 , 0.0664 , 0.0664 ],
[0.3572 , 0.3572 , 0.3572 ],
[0.388 , 0.388 , 0.388 ],
...,
[0.1753 , 0.1792 , 0.1674 ],
[0.1013 , 0.1054 , 0.10144 ],
[0.00772 , 0.01164 , 0.01152 ]],
...,
[[0.08234 , 0.0784 , 0.0844 ],
[0.512 , 0.512 , 0.516 ],
[0.563 , 0.563 , 0.557 ],
...,
[0.9844 , 0.9844 , 0.9844 ],
[0.9883 , 0.988 , 0.9883 ],
[0.9883 , 0.988 , 0.9883 ]],
[[0.0843 , 0.08264 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.549 , 0.549 , 0.549 ],
...,
[0.972 , 0.972 , 0.972 ],
[0.9766 , 0.9766 , 0.9766 ],
[0.9727 , 0.9727 , 0.9727 ]],
[[0.0843 , 0.0843 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.547 , 0.547 , 0.547 ],
...,
[0.4138 , 0.4138 , 0.4138 ],
[0.3975 , 0.3975 , 0.3975 ],
[0.3496 , 0.3496 , 0.3496 ]]],
[[[0.01581 , 0.01581 , 0.01581 ],
[0.0835 , 0.0835 , 0.0835 ],
[0.1042 , 0.1042 , 0.1042 ],
...,
[0.1631 , 0.1725 , 0.1611 ],
[0.08594 , 0.0938 , 0.0899 ],
[0.00784 , 0.011765, 0.011765]],
[[0.07623 , 0.07623 , 0.07623 ],
[0.442 , 0.442 , 0.442 ],
[0.3748 , 0.3748 , 0.3748 ],
...,
[0.2605 , 0.269 , 0.257 ],
[0.1082 , 0.116 , 0.1118 ],
[0.00784 , 0.011765, 0.011765]],
[[0.0646 , 0.0646 , 0.0646 ],
[0.3538 , 0.3538 , 0.3538 ],
[0.3918 , 0.3918 , 0.3918 ],
...,
[0.1735 , 0.1792 , 0.1655 ],
[0.1013 , 0.10724 , 0.10333 ],
[0.00772 , 0.01164 , 0.01152 ]],
...,
[[0.08234 , 0.0784 , 0.0844 ],
[0.512 , 0.512 , 0.516 ],
[0.563 , 0.563 , 0.557 ],
...,
[0.9844 , 0.9844 , 0.9844 ],
[0.9883 , 0.988 , 0.9883 ],
[0.9883 , 0.988 , 0.9883 ]],
[[0.0843 , 0.08264 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.549 , 0.549 , 0.549 ],
...,
[0.972 , 0.972 , 0.972 ],
[0.9766 , 0.9766 , 0.9766 ],
[0.9727 , 0.9727 , 0.9727 ]],
[[0.0843 , 0.0843 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.547 , 0.547 , 0.547 ],
...,
[0.4138 , 0.4138 , 0.4138 ],
[0.3975 , 0.3975 , 0.3975 ],
[0.3496 , 0.3496 , 0.3496 ]]],
[[[0.01581 , 0.01581 , 0.01581 ],
[0.0835 , 0.0835 , 0.0835 ],
[0.1042 , 0.1042 , 0.1042 ],
...,
[0.1624 , 0.1709 , 0.1592 ],
[0.08594 , 0.0938 , 0.0899 ],
[0.00784 , 0.011765, 0.011765]],
[[0.07623 , 0.07623 , 0.07623 ],
[0.442 , 0.442 , 0.442 ],
[0.3748 , 0.3748 , 0.3748 ],
...,
[0.2646 , 0.2747 , 0.261 ],
[0.1082 , 0.116 , 0.11017 ],
[0.00784 , 0.011765, 0.011765]],
[[0.0646 , 0.0646 , 0.0646 ],
[0.3538 , 0.3538 , 0.3538 ],
[0.3918 , 0.3918 , 0.3918 ],
...,
[0.1755 , 0.1792 , 0.1674 ],
[0.1013 , 0.1054 , 0.10144 ],
[0.00772 , 0.01164 , 0.01152 ]],
...,
[[0.08234 , 0.0784 , 0.0844 ],
[0.512 , 0.512 , 0.516 ],
[0.563 , 0.563 , 0.557 ],
...,
[0.9844 , 0.9844 , 0.9844 ],
[0.9883 , 0.988 , 0.9883 ],
[0.9883 , 0.988 , 0.9883 ]],
[[0.0843 , 0.08264 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.549 , 0.549 , 0.549 ],
...,
[0.972 , 0.972 , 0.972 ],
[0.9766 , 0.9766 , 0.9766 ],
[0.9727 , 0.9727 , 0.9727 ]],
[[0.0843 , 0.0843 , 0.0843 ],
[0.506 , 0.506 , 0.506 ],
[0.547 , 0.547 , 0.547 ],
...,
[0.4138 , 0.4138 , 0.4138 ],
[0.3975 , 0.3975 , 0.3975 ],
[0.3496 , 0.3496 , 0.3496 ]]]], dtype=float16)
目标:label_training_ar
类型:numpy 数组
形状:(2697, 30, 2)
示例:label_training_ar[10]
array([[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.]])
3。 VGG19 + LSTM 模型
3-1。代码
base_model=keras.applications.VGG19(include_top=False, input_shape=(160, 160, 3), weights='imagenet')
image_model=keras.models.Sequential()
image_model.add(base_model)
image_model.add(keras.layers.Flatten())
image_model.add(keras.layers.Dense(4096, activation='relu', name='fc1'))
image_model.add(keras.layers.Dense(4096, activation='relu', name='fc2'))
image_model.add(keras.layers.Dense(1000, activation='softmax', name='predictions'))
chunk_size=4096
n_chunks=30
rnn_size=512
model=keras.models.Sequential()
model.add(keras.layers.TimeDistributed(image_model, input_shape=(30, 160, 160, 3)))
model.add(keras.layers.LSTM(rnn_size, input_shape=(n_chunks, chunk_size))) # (30, 4096)
model.add(keras.layers.Dense(1024))
model.add(keras.layers.Activation('relu'))
model.add(keras.layers.Dense(256))
model.add(keras.layers.Activation('sigmoid'))
model.add(keras.layers.Dense(2))
model.add(keras.layers.Activation('softmax'))
model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])
3-2。绘图图像
4。模型拟合(模型训练)
epoch=100
batchS=30
history=model.fit(x=data_training_ar[0:2000], y=label_training_ar[0:2000], epochs=epoch,
validation_data=(data_training_ar[2000:], label_training_ar[2000:]),
callbacks=[checkpoint_cb], #keras.callbacks.ModelCheckpoint('210429_vc_13-02_checkpoint.h5', save_best_only=True)
batch_size=batchS, verbose=2)
尝试直接在命令行上使用 Spyder 或记事本和 运行 您的脚本。这是为了确保您的问题与 Web 服务器 运行ning Jupyter 的某些超时无关。它还将允许您查看完整的堆栈跟踪。
如果可能,请尝试使用 PyCharm 并查看错误是否仍然存在?还要检查它是否是相同的错误。
我 运行 VGG 系列模型在 Google Colab 中。它相当快。