在 tensorflow 上加载预训练的 vgg-16
loading pre-trained vgg-16 on tensorflow
我正在尝试使用 tensorflow r1.1 加载预训练的 vgg-16 网络。该网络在 3 个文件中提供:
- saved_model.pb
- variables/variables.index
- variables/variables.data-00000-of-00001
将变量 sess
初始化为 tf.Session()
后
我使用以下脚本加载网络并提取一些特定层:
vgg_path='./'
model_filename = os.path.join(vgg_path, "saved_model.pb")
export_dir = os.path.join(vgg_path, "variables/")
with gfile.FastGFile(model_filename, 'rb') as f:
data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
image_input, l7, l4, l3 = tf.import_graph_def(sm.meta_graphs[0].graph_def,
name='',return_elements=["image_input:0", "layer7_out:0",
"layer4_out:0", "layer3_out:0"])
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, image_input)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l7)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l4)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l3)
saver = tf.train.Saver(tf.global_variables())
print("load data")
saver.restore(sess, export_dir)
脚本在初始化变量 saver
时因以下错误而终止:
TypeError: Variable to save is not a Variable: Tensor("image_input:0",
shape=(?, ?, ?, 3), dtype=float32)
如何修复我的脚本并恢复预训练的 vgg 网络?
因为你有一个 SavedModel, you can use tf.saved_model.loader 用于加载它:
with tf.Session() as sess:
tf.saved_model.loader.load(sess, ["some_tag"], model_dir)
我正在尝试使用 tensorflow r1.1 加载预训练的 vgg-16 网络。该网络在 3 个文件中提供:
- saved_model.pb
- variables/variables.index
- variables/variables.data-00000-of-00001
将变量 sess
初始化为 tf.Session()
我使用以下脚本加载网络并提取一些特定层:
vgg_path='./'
model_filename = os.path.join(vgg_path, "saved_model.pb")
export_dir = os.path.join(vgg_path, "variables/")
with gfile.FastGFile(model_filename, 'rb') as f:
data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
image_input, l7, l4, l3 = tf.import_graph_def(sm.meta_graphs[0].graph_def,
name='',return_elements=["image_input:0", "layer7_out:0",
"layer4_out:0", "layer3_out:0"])
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, image_input)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l7)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l4)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l3)
saver = tf.train.Saver(tf.global_variables())
print("load data")
saver.restore(sess, export_dir)
脚本在初始化变量 saver
时因以下错误而终止:
TypeError: Variable to save is not a Variable: Tensor("image_input:0", shape=(?, ?, ?, 3), dtype=float32)
如何修复我的脚本并恢复预训练的 vgg 网络?
因为你有一个 SavedModel, you can use tf.saved_model.loader 用于加载它:
with tf.Session() as sess:
tf.saved_model.loader.load(sess, ["some_tag"], model_dir)