如何将预训练网络用作 Tensorflow 中的层?

How do I use a pretrained network as a layer in Tensorflow?

我想使用特征提取器(例如 ResNet101)并在其后添加使用特征提取器层输出的层。但是,我似乎无法弄清楚如何。我只在网上找到了使用整个网络而不添加额外层的解决方案。 我对 Tensorflow 没有经验。

在下面的代码中,您可以看到我的尝试。我可以 运行 在没有额外卷积层的情况下正确地编写代码,但是我的目标是在 ResNet 之后添加更多层。 尝试添加额外的 conv 层时,会返回此类型错误: 类型错误:应为 float32,得到 OrderedDict([('resnet_v1_101/conv1', ...

添加更多层后,我想开始在非常小的测试集上进行训练,看看我的模型是否会过拟合。


import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
import matplotlib.pyplot as plt

numclasses = 17

from google.colab import drive
drive.mount('/content/gdrive')

def decode_text(filename):
  img = tf.io.decode_jpeg(tf.io.read_file(filename))
  img = tf.image.resize_bilinear(tf.expand_dims(img, 0), [224, 224])
  img = tf.squeeze(img, 0)
  img.set_shape((None, None, 3))
  return img

dataset = tf.data.TextLineDataset(tf.cast('gdrive/My Drive/5LSM0collab/filenames.txt', tf.string))
dataset = dataset.map(decode_text)
dataset = dataset.batch(2, drop_remainder=True)

img_1 = dataset.make_one_shot_iterator().get_next()
net = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
net = slim.conv2d(net, numclasses, 1)


sess = tf.Session()

global_init = tf.global_variables_initializer()
local_init = tf.local_variables_initializer()
sess.run(global_init)
sess.run(local_init)
img_out, conv_out = sess.run((img_1, net))

resnet_v1.resnet_v1_101 不只是 return net,而是 return 是一个元组 net, end_points。第二个元素是字典,这大概就是您收到此特定错误消息的原因。

对于documentation of this function

Returns:

net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. If global_pool is False, then height_out and width_out are reduced by a factor of output_stride compared to the respective height_in and width_in, else both height_out and width_out equal one. If num_classes is 0 or None, then net is the output of the last ResNet block, potentially after global average pooling. If num_classes a non-zero integer, net contains the pre-softmax activations.

end_points: A dictionary from components of the network to the corresponding activation.

所以你可以这样写:

net, _ = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
net = slim.conv2d(net, numclasses, 1)

您也可以选择中间层,例如:

_, end_points = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
net = slim.conv2d(end_points["main_Scope/resnet_v1_101/block3"], numclasses, 1)

(您可以查看 end_points 以查找端点的名称。您的作用域名称将不同于 main_Scope。)