Flatten 层在 Keras 中如何工作?

How does the Flatten layer work in Keras?

我正在使用 TensorFlow 后端。

我正在依次应用卷积层、最大池化层、展平层和致密层。卷积需要一个 3D 输入(高度,宽度,color_channels_depth)。

卷积后变成(height,width,Number_of_filters).

应用最大池化后高度和宽度发生变化。但是,在应用展平层之后,到底发生了什么?比如flatten之前的输入是(24, 24, 32),那么它是怎么把它压平的?

它是像 (24 * 24) 一样按顺序排列高度、按顺序排列每个过滤器编号的重量,还是以其他方式排列?有实际值的示例将不胜感激。

像24*24*32这样连续的,按照下面的代码进行reshape。

def batch_flatten(x):
    """Turn a nD tensor into a 2D tensor with same 0th dimension.
    In other words, it flattens each data samples of a batch.
    # Arguments
        x: A tensor or variable.
    # Returns
        A tensor.
    """
    x = tf.reshape(x, tf.stack([-1, prod(shape(x)[1:])]))
    return x

Flatten() 运算符展开从最后一个维度开始的值(至少对于 Theano,这是 "channels first",而不是像 TF 那样的 "channels last"。我不能 运行 TensorFlow 在我的环境中)。这相当于 numpy.reshape 和 'C' 排序:

‘C’ means to read / write the elements using C-like index order, with the last axis index changing fastest, back to the first axis index changing slowest.

这是一个独立示例,说明了 Flatten 运算符与 Keras 函数 API。您应该能够轻松适应您的环境。

import numpy as np
from keras.layers import Input, Flatten
from keras.models import Model
inputs = Input(shape=(3,2,4))

# Define a model consisting only of the Flatten operation
prediction = Flatten()(inputs)
model = Model(inputs=inputs, outputs=prediction)

X = np.arange(0,24).reshape(1,3,2,4)
print(X)
#[[[[ 0  1  2  3]
#   [ 4  5  6  7]]
#
#  [[ 8  9 10 11]
#   [12 13 14 15]]
#
#  [[16 17 18 19]
#   [20 21 22 23]]]]
model.predict(X)
#array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
#         11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
#         22.,  23.]], dtype=float32)

展平张量意味着删除除一个维度之外的所有维度。

Keras 中的 Flatten 层重塑张量,使其形状等于张量中包含的元素数。

这与制作元素的一维数组是一样的。

例如在VGG16模型中你可能会觉得很容易理解:

>>> model.summary()
Layer (type)                     Output Shape          Param #
================================================================
vgg16 (Model)                    (None, 4, 4, 512)     14714688
________________________________________________________________
flatten_1 (Flatten)              (None, 8192)          0
________________________________________________________________
dense_1 (Dense)                  (None, 256)           2097408
________________________________________________________________
dense_2 (Dense)                  (None, 1)             257
===============================================================

注意flatten_1层的形状是(None, 8192),其中8192实际上是4*4*512。


PS, None 表示 any 维度(或动态维度),但您通常可以将其读作 1。您可以在中找到更多详细信息.