Freezing TensorFlow models with Exponential Moving Average gives different inferred probabilities
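I am trying to freeze an inception-v3 based model and run inference with it. However, the inferred probabilities I get from the frozen model are inconsistent with those of the original model.
I traced the discrepancy to the Exponential Moving Average (EMA) used in training and inference. When I turn EMA off in both models, I get the same probabilities (difference < 1e-5).
The freezing code I used: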
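For context, here is a minimal pure-Python sketch of why EMA-averaged weights can differ from the raw weights, and of the kind of probability comparison described above. The update rule is the standard one (`shadow = decay * shadow + (1 - decay) * value`). The weight trajectory, the helper names `ema_update` and `max_abs_diff`, and the sample probability vectors are made up for illustration and are not taken from the model in question.

```python
def ema_update(shadow, value, decay=0.9999):
    """One EMA step: shadow <- decay * shadow + (1 - decay) * value."""
    return decay * shadow + (1.0 - decay) * value


def max_abs_diff(probs_a, probs_b):
    """Largest element-wise gap between two probability vectors."""
    return max(abs(a - b) for a, b in zip(probs_a, probs_b))


# Hypothetical scalar weight that drifts upward during 1000 training steps.
raw = 0.0
shadow = 0.0  # the shadow copy starts at the weight's initial value
for step in range(1, 1001):
    raw = step * 0.001                # raw weight after this step
    shadow = ema_update(shadow, raw)  # averaged copy lags far behind

print("raw weight:   ", raw)
print("shadow weight:", shadow)

# Hypothetical outputs of two models on the same input.
original_probs = [0.7, 0.2, 0.1]
frozen_probs = [0.6, 0.3, 0.1]
print("max abs diff: ", max_abs_diff(original_probs, frozen_probs))
```

With a decay this close to 1, the shadow value trails the raw weight by a wide margin, so restoring the raw value for one variable and the averaged value for another can visibly change the output probabilities.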
from __future__ import print_function
import tensorflow as tf
from nets.inception_v3 import inception_v3, inception_v3_arg_scope
from tensorflow.python.framework import graph_util
import sys
slim = tf.contrib.slim
checkpoint_file = '/my/model'
with tf.Graph().as_default() as graph:
    images = tf.placeholder(shape=[None, 100, 221, 6], dtype=tf.float32, name='input')
    with slim.arg_scope(inception_v3_arg_scope()):
        logits, end_points = inception_v3(images, num_classes=3, create_aux_logits=False, is_training=False)
    variables_to_restore = slim.get_variables_to_restore()
    MOVING_AVERAGE_DECAY = 0.9999
    variable_averages = tf.train.ExponentialMovingAverage(
        MOVING_AVERAGE_DECAY)
    for var in variables_to_restore:
        tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
    variables_to_restore = variable_averages.variables_to_restore()  # This line is commented out if EMA is turned off
    saver = tf.train.Saver(variables_to_restore)
    # Set up the graph def
    input_graph_def = graph.as_graph_def()
    output_node_names = "InceptionV3/Predictions/Reshape_1"
    output_graph_name = "./frozen_inception_v3_new_100_221_ema.pb"
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_file)
        # Export the graph
        print("Exporting graph...")
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","))
        with tf.gfile.GFile(output_graph_name, "wb") as f:
            f.write(output_graph_def.SerializeToString())
The EMA part is the same as in the original model code.
Am I freezing the graph for EMA inference incorrectly?
Problem solved.
The EMA part I used,
MOVING_AVERAGE_DECAY = 0.9999
variable_averages = tf.train.ExponentialMovingAverage(
    MOVING_AVERAGE_DECAY)
for var in variables_to_restore:
    tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
variables_to_restore = variable_averages.variables_to_restore()
was incorrect.
If I remove
for var in variables_to_restore:
    tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
it works correctly now.
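One plausible explanation, sketched below in plain Python rather than TensorFlow: when `variables_to_restore()` is called without arguments, it treats the trainable variables plus everything in the `MOVING_AVERAGE_VARIABLES` collection as averaged, and looks those up in the checkpoint under `<name>/ExponentialMovingAverage`. All other variables are looked up under their own name. Adding every variable to the collection therefore also redirects non-trainable ones, such as batch-norm moving statistics, to their averaged copies. The variable names and the helpers `build_restore_map` and `checkpoint_key` are hypothetical and only mimic that mapping. Whether this is exactly what happened depends on how the original training code set up EMA, which is not shown here.

```python
EMA_SUFFIX = "ExponentialMovingAverage"  # default name of the EMA object


def average_name(var_name):
    """Checkpoint key under which the averaged copy of a variable is stored."""
    return var_name + "/" + EMA_SUFFIX


def build_restore_map(all_vars, averaged_vars):
    """Simplified model of the mapping: checkpoint key -> graph variable."""
    averaged = set(averaged_vars)
    name_map = {average_name(v): v for v in averaged}
    for v in all_vars:
        if v not in averaged and v not in name_map:
            name_map[v] = v  # non-averaged variables keep their own name
    return name_map


def checkpoint_key(restore_map, var_name):
    """Find which checkpoint key a graph variable would be restored from."""
    for key, var in restore_map.items():
        if var == var_name:
            return key
    return None


# Hypothetical variable names, for illustration only.
weights = "InceptionV3/Conv2d_1a_3x3/weights"
beta = "InceptionV3/Conv2d_1a_3x3/BatchNorm/beta"
moving_mean = "InceptionV3/Conv2d_1a_3x3/BatchNorm/moving_mean"
all_vars = [weights, beta, moving_mean]
trainable = [weights, beta]  # batch-norm moving statistics are not trainable

# With the loop: every variable was added to MOVING_AVERAGE_VARIABLES.
with_loop = build_restore_map(all_vars, averaged_vars=all_vars)
# Without the loop: only the trainable variables count as averaged.
without_loop = build_restore_map(all_vars, averaged_vars=trainable)

print("with loop:   ", checkpoint_key(with_loop, moving_mean))
print("without loop:", checkpoint_key(without_loop, moving_mean))
```

In this simplified model the trainable weights are restored from their averaged copies in both cases; only the non-trainable statistics change source, which would be enough to shift the predicted probabilities.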