在导出为 tflite 格式之前修复冻结图的输入节点
Fixing input node of frozen graph, before exporting to tflite format
我可以使用以下方法冻结图表:
freeze_graph.freeze_graph(input_graph=f"{save_graph_path}/graph.pbtxt",
input_saver="",
input_binary=False,
input_checkpoint=last_ckpt,
output_node_names="network/output_node",
restore_op_name="save/restore_all",
filename_tensor_name="save/Const:0",
output_graph=output_frozen_graph_name,
clear_devices=True,
initializer_nodes="")
但是,该图有两个值得注意的输入节点,即 "input/is_training" 和 "input/input_node"。
我想将此冻结图导出为 tflite 格式,但在这样做时我需要将 is_training 修复为 False(因为它用于 tf.layers.batch_normalization)。
我知道将 is_training 占位符设置为 False 可以解决此问题,但假设我只有冻结的图形文件和检查点,我将如何着手执行此操作?还是不可能?
您只需加载冻结的图表,将有问题的值映射到常量并再次保存图表即可。
import tensorflow as tf
with tf.Graph().as_default():
# Make constant False value (name does not need to match)
is_training = tf.constant(False, dtype=tf.bool, name="input/is_training")
# Load frozen graph
gd = tf.GraphDef()
with open(f"{save_graph_path}/graph.pbtxt", "r") as f:
gd.ParseFromString(f.read())
# Load graph mapping placeholder to constant
tf.import_graph_def(gd, name="", input_map={"input/is_training:0": is_training})
# Save graph again
tf.train.write_graph(tf.get_default_graph(), save_graph_path, "graph_modified.pbtxt",
as_text=True)
我可以使用以下方法冻结图表:
freeze_graph.freeze_graph(input_graph=f"{save_graph_path}/graph.pbtxt",
input_saver="",
input_binary=False,
input_checkpoint=last_ckpt,
output_node_names="network/output_node",
restore_op_name="save/restore_all",
filename_tensor_name="save/Const:0",
output_graph=output_frozen_graph_name,
clear_devices=True,
initializer_nodes="")
但是,该图有两个值得注意的输入节点,即 "input/is_training" 和 "input/input_node"。
我想将此冻结图导出为 tflite 格式,但在这样做时我需要将 is_training 修复为 False(因为它用于 tf.layers.batch_normalization)。
我知道将 is_training 占位符设置为 False 可以解决此问题,但假设我只有冻结的图形文件和检查点,我将如何着手执行此操作?还是不可能?
您只需加载冻结的图表,将有问题的值映射到常量并再次保存图表即可。
import tensorflow as tf
with tf.Graph().as_default():
# Make constant False value (name does not need to match)
is_training = tf.constant(False, dtype=tf.bool, name="input/is_training")
# Load frozen graph
gd = tf.GraphDef()
with open(f"{save_graph_path}/graph.pbtxt", "r") as f:
gd.ParseFromString(f.read())
# Load graph mapping placeholder to constant
tf.import_graph_def(gd, name="", input_map={"input/is_training:0": is_training})
# Save graph again
tf.train.write_graph(tf.get_default_graph(), save_graph_path, "graph_modified.pbtxt",
as_text=True)