tf.keras 替换预训练 resnet50 中的下层

tf.keras replace lower layer in pretrained resnet50

是否可以在 tf.keras.applications 中 remove/replace 预训练 ResNet50 模型的底层?

例如,我试过这样做:

import tensorflow as tf
pretrained_resnet = tf.keras.applications.ResNet50(include_top=False, weights='imagenet')
inputs = tf.keras.Input(shape=(256,256,1))
x = tf.keras.layers.ZeroPadding2D()(inputs)
x = tf.keras.layers.Conv2D(filters=64,
                           kernel_size=(7,7),
                           strides=(2,2),
                           padding='same')(x)
outputs = pretrained_resnet.layers[3](x)
test = tf.keras.Model(inputs, pretrained_resnet.output)

但它给出了这个错误:ValueError: Graph disconnected: cannot obtain value for tensorTensor("input_2:0", .......

我也尝试过使用 tf.keras 顺序 API,但这不起作用,因为 ResNet 不是顺序模型。我基本上只是想用一个新层替换 ResNet50 中的第一个 Conv2D 层。这可能吗?还是我必须重写整个 ResNet 模型?

如有任何建议,我们将不胜感激!

ZeroPadding2DConv2D (7*7, 64, stride 2)Resnet50 网络的 2nd3rd 层。

因此,此处显示仅替换 Resnet50 中的第一层(即输入层)

from tensorflow.keras.applications import ResNet50
import tensorflow as tf

model = ResNet50(include_top = False, weights = 'imagenet')
model.save('model.h5')

res50_model = tf.keras.models.load_model('model.h5')
#res50_model.summary()

要从网络中删除第一层,您可以运行代码如下

 res50_model._layers.pop(0)

Resnet50 expects the input must have 3 channels,因此将输入层形状添加为 (256,256,3) 而不是 (256,256,1).

要添加新的输入层,您可以运行代码如下

newInput = tf.keras.Input(shape=(256,256,3))
newOutputs = res50_model(newInput)
newModel = tf.keras.Model(newInput, newOutputs)
newModel.summary()

输出:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
resnet50 (Model)             multiple                  23587712  
=================================================================
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120
_________________________________________________________________