keras.layers.concatenate 是做什么的

What does keras.layers.concatenate do

我遇到了 the following code,想知道 keras.layers.concatenate 在这种情况下到底做了什么。

最佳猜测:

  1. fire_module()中,y基于每个像素进行学习(kernel_size=1)
  2. y1根据y(kernel_size=1)
  3. activation map的每个像素进行学习
  4. y3基于y(kernel_size=3)
  5. activation map的3x3像素区域进行学习
  6. concatenatey1y3 放在一起,这意味着总数 filters 现在是 y1y3[ 中过滤器的总和=54=]
  7. 这种连接是基于每个像素的学习,基于 3x3 的学习,两者都基于之前基于每个像素的激活图,使模型更好?

非常感谢任何帮助。

def fire(x, squeeze, expand):
    y  = Conv2D(filters=squeeze, kernel_size=1, activation='relu', padding='same')(x)
    y  = BatchNormalization(momentum=bnmomemtum)(y)
    y1 = Conv2D(filters=expand//2, kernel_size=1, activation='relu', padding='same')(y)
    y1 = BatchNormalization(momentum=bnmomemtum)(y1)
    y3 = Conv2D(filters=expand//2, kernel_size=3, activation='relu', padding='same')(y)
    y3 = BatchNormalization(momentum=bnmomemtum)(y3)
    return concatenate([y1, y3])

def fire_module(squeeze, expand):
    return lambda x: fire(x, squeeze, expand)
x = Input(shape=[144, 144, 3])
y = BatchNormalization(center=True, scale=False)(x)
y = Activation('relu')(y)
y = Conv2D(kernel_size=5, filters=16, padding='same', use_bias=True, activation='relu')(x)
y = BatchNormalization(momentum=bnmomemtum)(y)

y = fire_module(16, 32)(y)
y = MaxPooling2D(pool_size=2)(y)

编辑:

更具体一点,为什么不用这个:

# why not this?
def fire(x, squeeze, expand):
    y  = Conv2D(filters=squeeze, kernel_size=1, activation='relu', padding='same')(x)
    y  = BatchNormalization(momentum=bnmomemtum)(y)
    y = Conv2D(filters=expand//2, kernel_size=1, activation='relu', padding='same')(y)
    y = BatchNormalization(momentum=bnmomemtum)(y)
    y = Conv2D(filters=expand//2, kernel_size=3, activation='relu', padding='same')(y)
    y = BatchNormalization(momentum=bnmomemtum)(y)
    return y

当他解释串联时,我从这个 中引用了 @parsethis,如果将 a 串联到 b(结果连接在一起),它就是这样做的:

    a        b         c
a b c   g h i    a b c g h i
d e f   j k l    d e f j k l

The documentation 表示它只是 returns 一个包含所有输入串联的张量,前提是它们共享一个维度(即相同的长度或宽度,取决于轴)

你的情况看起来是这样的:

Y 
 \
  Y1----
   \    |
    Y3  Y1

希望我说得够清楚