saved_model.prune() 在 TF2.0 中
saved_model.prune() in TF2.0
我正在尝试修剪使用 tf.keras 生成的 SavedModel
的节点。剪枝脚本如下:
svmod = tf.saved_model.load(fn) #version 1
#svmod = tfk.experimental.load_from_saved_model(fn) #version 2
feeds = ['foo:0']
fetches = ['bar:0']
svmod2 = svmod.prune(feeds=feeds, fetches=fetches)
tf.saved_model.save(svmod2, '/tmp/saved_model/') #version 1
#tfk.experimental.export_saved_model(svmod2, '/tmp/saved_model/') #version 2
如果我使用版本 #1 修剪有效但在保存时给出 ValueError: Expected a Trackable object for export
。在版本 2 中,没有 prune() 方法。
如何修剪 TF2.0 Keras SavedModel?
由于您可以在版本 #1 中成功修剪,我建议您尝试 'pickle' 来保存模型。
尝试以下步骤来保存模型。
import pickle
with open('<model_name.pkl>', 'wb') as f:
pickle.dump(<your_model>, f)
读模型为:
with open('<model_name.pkl>', 'rb') as f:
model = pickle.load(f)
在您的情况下,对于版本 #1,代码段中的 your_model 是 svmod2.
看起来您在版本 1 中修剪模型的方式没问题;根据你的错误信息,修剪后的模型无法保存,因为它不是"trackable",这是保存tf.saved_model.save
模型的必要条件。制作可追踪对象的一种方法是继承tf.Module
class, as described in the guides for using the SavedModel format and concrete functions. Below is an example of trying to save a tf.function
对象(因为对象不可追踪而失败),继承自tf.module
,并保存结果对象:
(使用 Python 版本 3.7.6、TensorFlow 版本 2.1.0 和 NumPy 版本 1.18.1)
import tensorflow as tf, numpy as np
# Define a random TensorFlow function and generate a reference output
conv_filter = tf.random.normal([1, 2, 4, 2], seed=1254)
@tf.function
def conv_model(x):
return tf.nn.conv2d(x, conv_filter, 1, "SAME")
input_tensor = tf.ones([1, 2, 3, 4])
output_tensor = conv_model(input_tensor)
print("Original model outputs:", output_tensor, sep="\n")
# Try saving the model: it won't work because a tf.function is not trackable
export_dir = "./tmp/"
try: tf.saved_model.save(conv_model, export_dir)
except ValueError: print(
"Can't save {} object because it's not trackable".format(type(conv_model)))
# Now define a trackable object by inheriting from the tf.Module class
class MyModule(tf.Module):
@tf.function
def __call__(self, x): return conv_model(x)
# Instantiate the trackable object, and call once to trace-compile a graph
module_func = MyModule()
module_func(input_tensor)
tf.saved_model.save(module_func, export_dir)
# Restore the model and verify that the outputs are consistent
restored_model = tf.saved_model.load(export_dir)
restored_output_tensor = restored_model(input_tensor)
print("Restored model outputs:", restored_output_tensor, sep="\n")
if np.array_equal(output_tensor.numpy(), restored_output_tensor.numpy()):
print("Outputs are consistent :)")
else: print("Outputs are NOT consistent :(")
控制台输出:
Original model outputs:
tf.Tensor(
[[[[-2.3629642 1.2904963 ]
[-2.3629642 1.2904963 ]
[-0.02110204 1.3400152 ]]
[[-2.3629642 1.2904963 ]
[-2.3629642 1.2904963 ]
[-0.02110204 1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)
Can't save <class 'tensorflow.python.eager.def_function.Function'> object
because it's not trackable
Restored model outputs:
tf.Tensor(
[[[[-2.3629642 1.2904963 ]
[-2.3629642 1.2904963 ]
[-0.02110204 1.3400152 ]]
[[-2.3629642 1.2904963 ]
[-2.3629642 1.2904963 ]
[-0.02110204 1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)
Outputs are consistent :)
因此您应该尝试如下修改您的代码:
svmod = tf.saved_model.load(fn) #version 1
svmod2 = svmod.prune(feeds=['foo:0'], fetches=['bar:0'])
class Exportable(tf.Module):
@tf.function
def __call__(self, model_inputs): return svmod2(model_inputs)
svmod2_export = Exportable()
svmod2_export(typical_input) # call once with typical input to trace-compile
tf.saved_model.save(svmod2_export, '/tmp/saved_model/')
如果您不想从 tf.Module
继承,您也可以只实例化一个 tf.Module
对象并通过替换该部分来添加 tf.function
method/callable 属性代码如下:
to_export = tf.Module()
to_export.call = tf.function(conv_model)
to_export.call(input_tensor)
tf.saved_model.save(to_export, export_dir)
restored_module = tf.saved_model.load(export_dir)
restored_func = restored_module.call
我正在尝试修剪使用 tf.keras 生成的 SavedModel
的节点。剪枝脚本如下:
svmod = tf.saved_model.load(fn) #version 1
#svmod = tfk.experimental.load_from_saved_model(fn) #version 2
feeds = ['foo:0']
fetches = ['bar:0']
svmod2 = svmod.prune(feeds=feeds, fetches=fetches)
tf.saved_model.save(svmod2, '/tmp/saved_model/') #version 1
#tfk.experimental.export_saved_model(svmod2, '/tmp/saved_model/') #version 2
如果我使用版本 #1 修剪有效但在保存时给出 ValueError: Expected a Trackable object for export
。在版本 2 中,没有 prune() 方法。
如何修剪 TF2.0 Keras SavedModel?
由于您可以在版本 #1 中成功修剪,我建议您尝试 'pickle' 来保存模型。 尝试以下步骤来保存模型。
import pickle
with open('<model_name.pkl>', 'wb') as f:
pickle.dump(<your_model>, f)
读模型为:
with open('<model_name.pkl>', 'rb') as f:
model = pickle.load(f)
在您的情况下,对于版本 #1,代码段中的 your_model 是 svmod2.
看起来您在版本 1 中修剪模型的方式没问题;根据你的错误信息,修剪后的模型无法保存,因为它不是"trackable",这是保存tf.saved_model.save
模型的必要条件。制作可追踪对象的一种方法是继承tf.Module
class, as described in the guides for using the SavedModel format and concrete functions. Below is an example of trying to save a tf.function
对象(因为对象不可追踪而失败),继承自tf.module
,并保存结果对象:
(使用 Python 版本 3.7.6、TensorFlow 版本 2.1.0 和 NumPy 版本 1.18.1)
import tensorflow as tf, numpy as np
# Define a random TensorFlow function and generate a reference output
conv_filter = tf.random.normal([1, 2, 4, 2], seed=1254)
@tf.function
def conv_model(x):
return tf.nn.conv2d(x, conv_filter, 1, "SAME")
input_tensor = tf.ones([1, 2, 3, 4])
output_tensor = conv_model(input_tensor)
print("Original model outputs:", output_tensor, sep="\n")
# Try saving the model: it won't work because a tf.function is not trackable
export_dir = "./tmp/"
try: tf.saved_model.save(conv_model, export_dir)
except ValueError: print(
"Can't save {} object because it's not trackable".format(type(conv_model)))
# Now define a trackable object by inheriting from the tf.Module class
class MyModule(tf.Module):
@tf.function
def __call__(self, x): return conv_model(x)
# Instantiate the trackable object, and call once to trace-compile a graph
module_func = MyModule()
module_func(input_tensor)
tf.saved_model.save(module_func, export_dir)
# Restore the model and verify that the outputs are consistent
restored_model = tf.saved_model.load(export_dir)
restored_output_tensor = restored_model(input_tensor)
print("Restored model outputs:", restored_output_tensor, sep="\n")
if np.array_equal(output_tensor.numpy(), restored_output_tensor.numpy()):
print("Outputs are consistent :)")
else: print("Outputs are NOT consistent :(")
控制台输出:
Original model outputs:
tf.Tensor(
[[[[-2.3629642 1.2904963 ]
[-2.3629642 1.2904963 ]
[-0.02110204 1.3400152 ]]
[[-2.3629642 1.2904963 ]
[-2.3629642 1.2904963 ]
[-0.02110204 1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)
Can't save <class 'tensorflow.python.eager.def_function.Function'> object
because it's not trackable
Restored model outputs:
tf.Tensor(
[[[[-2.3629642 1.2904963 ]
[-2.3629642 1.2904963 ]
[-0.02110204 1.3400152 ]]
[[-2.3629642 1.2904963 ]
[-2.3629642 1.2904963 ]
[-0.02110204 1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)
Outputs are consistent :)
因此您应该尝试如下修改您的代码:
svmod = tf.saved_model.load(fn) #version 1
svmod2 = svmod.prune(feeds=['foo:0'], fetches=['bar:0'])
class Exportable(tf.Module):
@tf.function
def __call__(self, model_inputs): return svmod2(model_inputs)
svmod2_export = Exportable()
svmod2_export(typical_input) # call once with typical input to trace-compile
tf.saved_model.save(svmod2_export, '/tmp/saved_model/')
如果您不想从 tf.Module
继承,您也可以只实例化一个 tf.Module
对象并通过替换该部分来添加 tf.function
method/callable 属性代码如下:
to_export = tf.Module()
to_export.call = tf.function(conv_model)
to_export.call(input_tensor)
tf.saved_model.save(to_export, export_dir)
restored_module = tf.saved_model.load(export_dir)
restored_func = restored_module.call