如何在 Tensorflow version='2.0.0' 中保存和加载 tf.estimator.BoostedTreesRegressor 的模型

How to save and load model of tf.estimator.BoostedTreesRegressor in the Tensorflow version='2.0.0'

我是 tf.estimator.BoostedTreesRegressor 的新手。这是我用来构建模型的示例代码。

n_batches = 20

est = tf.estimator.BoostedTreesRegressor(feature_columns,
                                           n_batches_per_layer=n_batches , learning_rate=0.001, n_trees=700,
                                            max_depth=13, 
model_dir = "model", tf.config.threading.set_intra_op_parallelism_threads(60))

est.train(train_input_fn, max_steps=10)

我想保存模型..然后加载模型来预测销量。

你能帮我看看如何在 TensorFlow 版本 2 中做到这一点吗..因为我找不到...

谢谢

你的模型应该按照官方documentation保存在model_dir路径下。请在实例化 BoostedTreesRegressor 时指定 model_dir 的真实目录路径。

此外,您可以使用export_saved_model方法保存模型。

# Saving estimator model
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec(feature_columns))
export_path = estimator.export_saved_model("/dir/path/", serving_input_fn)

要加载已保存的模型,您可以使用 saved_model.load 函数,如下所示:

#loading saved model
imported = tf.saved_model.load(export_path)