来自 sess.run(() 的模型权重以字节为单位返回值。如何更改为值?
Model weights from sess.run(() is returning the value in bytes. How can I change to value?
我正在尝试从 .pb
文件中保存的模型中提取模型权重。但是,当我 运行 sess 它 returns 模型权重以字节为单位时,我无法读取它。我的代码如下:
constant_values = {}
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
meta_graph = tf.compat.v1.saved_model.loader.load(sess,[tf.compat.v1.saved_model.tag_constants.SERVING],'model_2/1/')
tf.import_graph_def(meta_graph.graph_def, name='')
constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
x=0
for constant_op in constant_ops:
x = constant_op.outputs[0]
value = sess.run(constant_op.outputs[0])
constant_values[constant_op.name] = valu
print(constant_op.name, value)
这里有一段么returns:
b'\n\x1b\n\t\x08\x01\x12\x05model\n\x0e\x08\x02\x12\nsignatures\n\xe2\x01\n\x18\x08\x03\x12\x14layer_with_weights-0\n\x0b\x08\x03\x12\x07layer-0\n\x0b\x08\x04\x12\x07layer-1\n\x18\x08\x05\x12\x14layer_with_weights-1\n\x0b\x08\x05\x12\x07layer-2\n\r\x08\x06\x12\tvariables\n\x17\x08\x07\x12\x13trainable_variables\n\x19\x08\x08\x12\x15regularization_losses\n\r\x08\t\x12\tkeras_api\n\x0e\x08\n\x12\nsignatures\n#\x08\x0b\x12\x1f_self_saveable_object_factories\n\x00\n\x92R\n\x0b\x08\x0c\x12\x07layer-0\n\x0b\x08\r\x12\x07layer-1\n\x18\x08\x0e\x12\x14layer_with_weights-0\n\x0b\x08\x0e\x12\x07layer-2\n\x0b\x08\x0f\x12\x07layer-3\n\x18\x08\x10\x12\x14layer_with_weights-1\n\x0b\x08\x10\x12\x07layer-4\n\x18\x08\x11\x12\x14layer_with_weights-2...
谢谢
您确定图表中具有模型权重的常量变量名为 'Const' 吗?
如果您只是从有关如何在其他地方获取模型权重的教程中复制此代码(正如我过去所见),请尝试以下操作:
而不是 constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
尝试 constant_ops = [op for op in sess.graph.get_operations()]
并查看图中所有张量和操作的样子。您可能会发现权重节点的命名方式不同。
最佳,
我正在尝试从 .pb
文件中保存的模型中提取模型权重。但是,当我 运行 sess 它 returns 模型权重以字节为单位时,我无法读取它。我的代码如下:
constant_values = {}
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
meta_graph = tf.compat.v1.saved_model.loader.load(sess,[tf.compat.v1.saved_model.tag_constants.SERVING],'model_2/1/')
tf.import_graph_def(meta_graph.graph_def, name='')
constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
x=0
for constant_op in constant_ops:
x = constant_op.outputs[0]
value = sess.run(constant_op.outputs[0])
constant_values[constant_op.name] = valu
print(constant_op.name, value)
这里有一段么returns:
b'\n\x1b\n\t\x08\x01\x12\x05model\n\x0e\x08\x02\x12\nsignatures\n\xe2\x01\n\x18\x08\x03\x12\x14layer_with_weights-0\n\x0b\x08\x03\x12\x07layer-0\n\x0b\x08\x04\x12\x07layer-1\n\x18\x08\x05\x12\x14layer_with_weights-1\n\x0b\x08\x05\x12\x07layer-2\n\r\x08\x06\x12\tvariables\n\x17\x08\x07\x12\x13trainable_variables\n\x19\x08\x08\x12\x15regularization_losses\n\r\x08\t\x12\tkeras_api\n\x0e\x08\n\x12\nsignatures\n#\x08\x0b\x12\x1f_self_saveable_object_factories\n\x00\n\x92R\n\x0b\x08\x0c\x12\x07layer-0\n\x0b\x08\r\x12\x07layer-1\n\x18\x08\x0e\x12\x14layer_with_weights-0\n\x0b\x08\x0e\x12\x07layer-2\n\x0b\x08\x0f\x12\x07layer-3\n\x18\x08\x10\x12\x14layer_with_weights-1\n\x0b\x08\x10\x12\x07layer-4\n\x18\x08\x11\x12\x14layer_with_weights-2...
谢谢
您确定图表中具有模型权重的常量变量名为 'Const' 吗?
如果您只是从有关如何在其他地方获取模型权重的教程中复制此代码(正如我过去所见),请尝试以下操作:
而不是 constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
尝试 constant_ops = [op for op in sess.graph.get_operations()]
并查看图中所有张量和操作的样子。您可能会发现权重节点的命名方式不同。
最佳,