keras:使用一个模型输出作为另一个模型输入
keras: Use one model output as another model input
我在 InceptionResNetV2 模型(预训练)之前添加了一个密集层
这是 InceptionResNetV2 输出
model_base = InceptionResNetV2(include_top=True, weights='imagenet')
x = model_base.get_layer('avg_pool').output
x = Dense(3, activation='softmax')(x)
这是将添加的图层
input1 = Input(shape=input_shape1)
pre1 = Conv2D(filters=3, kernel_size=(5, 5), padding='SAME',
input_shape=input_shape1, name='first_dense')(input1)
pre = Model(inputs=input1, outputs=pre1)
这是结合了两个模型
after = Model(inputs=pre.output, outputs=x)
model = Model(inputs=input1, outputs=after.output)
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
使用
pre.output
作为
after.input
但这不起作用。我该如何解决?
首先让我们从 model_base 创建一个新模型,因为您希望获得更早的输出。
您的代码:
model_base = InceptionResNetV2(include_top=True, weights='imagenet')
x = model_base.get_layer('avg_pool').output
x = Dense(3, activation='softmax')(x)
新 model_base
:
model_base = Model(model_base.input, x)
现在,重要的是将输出 pre1
传递给此模型:
base_out = model_base(pre1)
就是这样:
model = Model(input1, base_out)
我在 InceptionResNetV2 模型(预训练)之前添加了一个密集层 这是 InceptionResNetV2 输出
model_base = InceptionResNetV2(include_top=True, weights='imagenet')
x = model_base.get_layer('avg_pool').output
x = Dense(3, activation='softmax')(x)
这是将添加的图层
input1 = Input(shape=input_shape1)
pre1 = Conv2D(filters=3, kernel_size=(5, 5), padding='SAME',
input_shape=input_shape1, name='first_dense')(input1)
pre = Model(inputs=input1, outputs=pre1)
这是结合了两个模型
after = Model(inputs=pre.output, outputs=x)
model = Model(inputs=input1, outputs=after.output)
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
使用
pre.output
作为
after.input
但这不起作用。我该如何解决?
首先让我们从 model_base 创建一个新模型,因为您希望获得更早的输出。
您的代码:
model_base = InceptionResNetV2(include_top=True, weights='imagenet')
x = model_base.get_layer('avg_pool').output
x = Dense(3, activation='softmax')(x)
新 model_base
:
model_base = Model(model_base.input, x)
现在,重要的是将输出 pre1
传递给此模型:
base_out = model_base(pre1)
就是这样:
model = Model(input1, base_out)