从张量流模型中获取权重
Get weights from tensorflow model
您好,我想从 tensorflow 微调 VGG 模型。我有两个问题。
如何从网络中获取权重?我的 trainable_variables returns 空列表。
我使用了这里的现有模型:https://github.com/ry/tensorflow-vgg16。
我发现 post 关于获取权重,但这对我不起作用,因为 import_graph_def。
import tensorflow as tf
import PIL.Image
import numpy as np
with open("../vgg16.tfmodel", mode='rb') as f:
fileContent = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
images = tf.placeholder("float", [None, 224, 224, 3])
tf.import_graph_def(graph_def, input_map={ "images": images })
print("graph loaded from disk")
graph = tf.get_default_graph()
cat = np.asarray(PIL.Image.open('../cat224.jpg'))
print(cat.shape)
init = tf.initialize_all_variables()
with tf.Session(graph=graph) as sess:
print(tf.trainable_variables() )
sess.run(init)
这 pretrained VGG-16 model encodes all of the model parameters as tf.constant()
ops. (See, for example, the calls to tf.constant()
here.) As a result, the model parameters would not appear in tf.trainable_variables()
, and the model is not mutable without substantial surgery: you would need to replace the constant nodes with tf.Variable
个以相同值开始的对象,以便继续训练。
一般来说,在导入图形进行再训练时,tf.train.import_meta_graph()
function should be used, as this function loads additional metadata (including the collections of variables). The tf.import_graph_def()
函数级别较低,不会填充这些集合。
您好,我想从 tensorflow 微调 VGG 模型。我有两个问题。
如何从网络中获取权重?我的 trainable_variables returns 空列表。
我使用了这里的现有模型:https://github.com/ry/tensorflow-vgg16。
我发现 post 关于获取权重,但这对我不起作用,因为 import_graph_def。
import tensorflow as tf
import PIL.Image
import numpy as np
with open("../vgg16.tfmodel", mode='rb') as f:
fileContent = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
images = tf.placeholder("float", [None, 224, 224, 3])
tf.import_graph_def(graph_def, input_map={ "images": images })
print("graph loaded from disk")
graph = tf.get_default_graph()
cat = np.asarray(PIL.Image.open('../cat224.jpg'))
print(cat.shape)
init = tf.initialize_all_variables()
with tf.Session(graph=graph) as sess:
print(tf.trainable_variables() )
sess.run(init)
这 pretrained VGG-16 model encodes all of the model parameters as tf.constant()
ops. (See, for example, the calls to tf.constant()
here.) As a result, the model parameters would not appear in tf.trainable_variables()
, and the model is not mutable without substantial surgery: you would need to replace the constant nodes with tf.Variable
个以相同值开始的对象,以便继续训练。
一般来说,在导入图形进行再训练时,tf.train.import_meta_graph()
function should be used, as this function loads additional metadata (including the collections of variables). The tf.import_graph_def()
函数级别较低,不会填充这些集合。