Error when parsing graph_def from string
I am trying to run a very simple save of a TensorFlow graph to a .pb file, but I get this error when parsing it:
Traceback (most recent call last):
File "test_import_stripped_bm.py", line 28, in <module>
graph_def.ParseFromString(fileContent)
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/message.py", line 185, in ParseFromString
self.MergeFromString(serialized)
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1069, in MergeFromString
if self._InternalParse(serialized, 0, length) != length:
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 633, in DecodeField
if value._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 743, in DecodeMap
if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1095, in InternalParse
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 850, in SkipField
return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 799, in _SkipGroup
new_pos = SkipField(buffer, pos, end, tag_bytes)
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 850, in SkipField
return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 814, in _SkipFixed32
raise _DecodeError('Truncated message.')
google.protobuf.message.DecodeError: Truncated message.
This is the code I use to write it to the .pb file:
import tensorflow as tf
builder = tf.saved_model.builder.SavedModelBuilder('models/TEST-3')
w1 = tf.Variable(tf.random_normal((2,2)), name="w1")
w2 = tf.Variable(tf.random_normal((2,2)), name="w2")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
builder.add_meta_graph_and_variables(sess, tags=[tf.saved_model.tag_constants.SERVING], clear_devices = True)
builder.save()
sess.close()
This is the code that parses it:
import tensorflow as tf
import os
model_path = os.path.join('models/TEST-3', 'saved_model.pb')
with open(model_path, mode='rb') as f:
    fileContent = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
To see the exact error, I had to run
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
before running the script.
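The same switch can be made from inside Python. This is a minimal sketch, assuming the variable is set before anything (including TensorFlow, which imports protobuf) first imports google.protobuf, because the backend is chosen once at import time. `api_implementation` is an internal protobuf module, so its use here is an assumption that may not hold across protobuf versions.

```python
import os

# Must run before google.protobuf (or tensorflow) is first imported in this process;
# protobuf selects its implementation once, at import time.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from google.protobuf.internal import api_implementation  # internal module

# Reports the active backend; expected to be "python" after the line above.
print(api_implementation.Type())
```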
I have also tried this on Python 2 and 3 with different TensorFlow versions, running on Ubuntu 16.04. On Python 2.7 with TensorFlow 0.9.0rc0, I got a slightly different error:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/message.py", line 185, in ParseFromString
self.MergeFromString(serialized)
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1091, in MergeFromString
if self._InternalParse(serialized, 0, length) != length:
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 633, in DecodeField
if value._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
pos = field_decoder(buffer, new_pos, end, self, field_dict)
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 489, in DecodeRepeatedField
value.append(_ConvertToUnicode(buffer[pos:new_pos]))
File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 469, in _ConvertToUnicode
return local_unicode(byte_str, 'utf-8')
UnicodeDecodeError: 'utf8' codec can't decode byte 0x80 in position 18: 'utf8' codec can't decode byte 0x80 in position 18: invalid start byte in field: tensorflow.FunctionDef.Node.ret
I can parse other .pb graphs with this code, for example this one: https://github.com/taey16/tf/blob/master/imagenet/classify_image_graph_def.pb
Thanks in advance.
The problem here is that you are trying to parse a SavedModel protocol buffer as if it were a GraphDef. Although a SavedModel contains a GraphDef, the two have different binary formats. The following code, using tf.saved_model.loader.load(), should work:
import tensorflow as tf
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], "models/TEST-3")
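If what you actually want is the GraphDef stored inside saved_model.pb (for example, to inspect its nodes), you can parse the file as the message type it really is and take the graph out of it. This is a minimal sketch, assuming `tensorflow.core.protobuf.saved_model_pb2` (shipped with TensorFlow) is importable; the helper name `graph_def_from_saved_model_bytes` is hypothetical. Note that the variable values live in the SavedModel's `variables/` directory, not in this GraphDef.

```python
from tensorflow.core.protobuf import saved_model_pb2


def graph_def_from_saved_model_bytes(data, tag="serve"):
    """Return the GraphDef of the first MetaGraphDef carrying `tag`.

    "serve" is the value of tf.saved_model.tag_constants.SERVING.
    """
    saved_model = saved_model_pb2.SavedModel()
    # Parse with the correct top-level message type instead of GraphDef.
    saved_model.ParseFromString(data)
    for meta_graph in saved_model.meta_graphs:
        if tag in meta_graph.meta_info_def.tags:
            return meta_graph.graph_def
    raise ValueError("no MetaGraphDef tagged %r" % tag)


# Hypothetical usage with the model from the question:
# with open("models/TEST-3/saved_model.pb", "rb") as f:
#     graph_def = graph_def_from_saved_model_bytes(f.read())
```

Parsing the outer SavedModel message and selecting by tag mirrors what the loader does to pick a MetaGraphDef, without needing a Session.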
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
Here, fileContent should be a "frozen graph". TensorFlow also provides an API for creating one; see the TensorFlow freeze_graph API.
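Before calling ParseFromString, it can help to check which message a .pb file actually holds. This is a stdlib-only sketch (the helper names are hypothetical) that walks the top level of the protobuf wire format. Per TensorFlow's saved_model.proto, SavedModel field 1 is the int64 saved_model_schema_version (wire type 0) and field 2 is meta_graphs; per graph.proto, GraphDef field 1 is the repeated `node` (wire type 2) and field 2 is `library`. That plausibly explains the tracebacks in the question: a GraphDef parser skips the unexpected varint and decodes the meta_graphs bytes as a FunctionDefLibrary, hence the FunctionDef field in the Python 2.7 error. Treat the result as a heuristic, not a validation.

```python
def read_varint(buf, pos):
    """Decode a base-128 varint starting at pos; return (value, new_pos)."""
    result, shift = 0, 0
    while True:
        if pos >= len(buf):
            raise ValueError("truncated varint")
        b = buf[pos]
        pos += 1
        result |= (b & 0x7F) << shift
        if not b & 0x80:
            return result, pos
        shift += 7


def top_level_fields(buf):
    """List (field_number, wire_type) for each top-level field, skipping payloads."""
    fields, pos = [], 0
    while pos < len(buf):
        key, pos = read_varint(buf, pos)
        field_number, wire_type = key >> 3, key & 0x7
        if wire_type == 0:        # varint
            _, pos = read_varint(buf, pos)
        elif wire_type == 1:      # fixed64
            pos += 8
        elif wire_type == 2:      # length-delimited (bytes, strings, sub-messages)
            length, pos = read_varint(buf, pos)
            pos += length
        elif wire_type == 5:      # fixed32
            pos += 4
        else:                     # groups (3, 4) are not used by these protos
            raise ValueError("unsupported wire type %d" % wire_type)
        if pos > len(buf):
            raise ValueError("truncated message")
        fields.append((field_number, wire_type))
    return fields


def guess_pb_kind(buf):
    """Heuristic: SavedModel has field 1 as a varint, GraphDef as a sub-message."""
    fields = top_level_fields(buf)
    if (1, 0) in fields:
        return "SavedModel"
    if (1, 2) in fields:
        return "GraphDef"
    return "unknown"
```

For example, `guess_pb_kind(open("models/TEST-3/saved_model.pb", "rb").read())` would be expected to report "SavedModel" for the file from the question.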
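Another way to create a frozen graph is:
import tensorflow as tf

meta_file = "<.meta file>"  # placeholder: path to the .meta file written by tf.train.Saver
checkpoint = "<checkpoint>"  # placeholder: checkpoint prefix, i.e. the path returned by saver.save()
output_node_names = ["<output node name>"]  # placeholder: list of output node names (strings)

with tf.Session(graph=tf.Graph()) as sess:
    saver = tf.train.import_meta_graph(meta_file)
    saver.restore(sess, checkpoint)
    # Replace every variable reachable from the outputs with a constant holding its value.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        output_node_names
    )
# Save output_graph_def to a file; this file is the frozen graph.
with tf.gfile.GFile('frozen_graph.pb', "wb") as f:
    f.write(output_graph_def.SerializeToString())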
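Then parse frozen_graph.pb like this (ParseFromString expects the file's serialized bytes, not its file name):
graph_def = tf.GraphDef()
with open("frozen_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())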
So first use a Saver object to generate the .meta and checkpoint files. Once that is done, create the frozen graph.
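Once frozen_graph.pb exists, the parsed GraphDef is usually imported into a fresh Graph before running it. This is a minimal sketch, assuming a TensorFlow version that provides tf.compat.v1 (late 1.x releases and 2.x; on older 1.x, drop the `compat.v1`); the helper name `load_frozen_graph` is hypothetical.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style API; available in late 1.x releases and in 2.x


def load_frozen_graph(path):
    """Parse a frozen GraphDef file and import it into a new tf.Graph."""
    graph_def = tf1.GraphDef()
    with tf1.gfile.GFile(path, "rb") as f:
        # ParseFromString needs the file's bytes, not its name.
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original node names (the default prefix is "import/").
        tf1.import_graph_def(graph_def, name="")
    return graph
```

Importing into a dedicated Graph keeps the frozen model's nodes separate from anything else in the default graph; tensors can then be fetched with `graph.get_tensor_by_name("<node name>:0")`.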