如何计算用从 .pb 文件加载的图形定义的张量流模型中可训练参数的总数?
How to count total number of trainable parameters in a tensorflow model defined with graph loaded from .pb file?
我想计算张量流模型中的参数。它与现有问题类似,如下所示。
但是如果模型是用从 .pb 文件加载的图形定义的,则所有建议的答案都不起作用。基本上我用以下函数加载了图表。
def load_graph(model_file):
graph = tf.Graph()
graph_def = tf.GraphDef()
with open(model_file, "rb") as f:
graph_def.ParseFromString(f.read())
with graph.as_default():
tf.import_graph_def(graph_def)
return graph
一个示例是在 tensorflow-for-poets-2 中加载一个 frozen_graph.pb 文件用于再训练目的。
据我了解,GraphDef
没有足够的信息来描述 Variables
。正如 here 所解释的,您将需要 MetaGraph
,其中包含 GraphDef
和 CollectionDef
,后者是可以描述 Variables
的地图。所以下面的代码应该给我们正确的可训练变量计数。
导出元图:
import tensorflow as tf
a = tf.get_variable('a', shape=[1])
b = tf.get_variable('b', shape=[1], trainable=False)
init = tf.global_variables_initializer()
saver = tf.train.Saver([a])
with tf.Session() as sess:
sess.run(init)
saver.save(sess, r'.\test')
导入 MetaGraph 并计算可训练参数的总数。
import tensorflow as tf
saver = tf.train.import_meta_graph('test.meta')
with tf.Session() as sess:
saver.restore(sess, 'test')
total_parameters = 0
for variable in tf.trainable_variables():
total_parameters += 1
print(total_parameters)
我想计算张量流模型中的参数。它与现有问题类似,如下所示。
但是如果模型是用从 .pb 文件加载的图形定义的,则所有建议的答案都不起作用。基本上我用以下函数加载了图表。
def load_graph(model_file):
graph = tf.Graph()
graph_def = tf.GraphDef()
with open(model_file, "rb") as f:
graph_def.ParseFromString(f.read())
with graph.as_default():
tf.import_graph_def(graph_def)
return graph
一个示例是在 tensorflow-for-poets-2 中加载一个 frozen_graph.pb 文件用于再训练目的。
据我了解,GraphDef
没有足够的信息来描述 Variables
。正如 here 所解释的,您将需要 MetaGraph
,其中包含 GraphDef
和 CollectionDef
,后者是可以描述 Variables
的地图。所以下面的代码应该给我们正确的可训练变量计数。
导出元图:
import tensorflow as tf
a = tf.get_variable('a', shape=[1])
b = tf.get_variable('b', shape=[1], trainable=False)
init = tf.global_variables_initializer()
saver = tf.train.Saver([a])
with tf.Session() as sess:
sess.run(init)
saver.save(sess, r'.\test')
导入 MetaGraph 并计算可训练参数的总数。
import tensorflow as tf
saver = tf.train.import_meta_graph('test.meta')
with tf.Session() as sess:
saver.restore(sess, 'test')
total_parameters = 0
for variable in tf.trainable_variables():
total_parameters += 1
print(total_parameters)