tensorflow.contrib.learn.ExportStrategy 的示例
Example of tensorflow.contrib.learn.ExportStrategy
有人可以提供 TensorFlow 中 `tf.contrib.learn.ExportStrategy` 的完整可运行代码示例吗?
官方文档中没有示例。对于这个看似冷门的 TensorFlow API,我在 GitHub 和 Stack Overflow 上也找不到任何示例。
文档:https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ExportStrategy
Google Cloud ML 在这里提供了一个很好的可运行示例:
https://github.com/GoogleCloudPlatform/cloudml-samples/tree/master/census/customestimator/trainer
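要让示例真正运行起来,您需要使用他们的完整代码(下面的片段引用了 `parse_csv`、`LABEL_COLUMN`、`INPUT_COLUMNS`、`generate_experiment_fn` 和 `args`,这些名称都需要从完整示例中获得)。下面是如何使用 ExportStrategy 的要点: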
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.utils import (
saved_model_export_utils)
from tensorflow.contrib.training.python.training import hparam
def csv_serving_input_fn():
    """Serving input function that accepts raw CSV rows.

    A rank-1 string placeholder receives the incoming rows. It is parsed with
    ``parse_csv``, and ``LABEL_COLUMN`` is removed from the result because a
    served model only receives features.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver`` whose features are the
        output of ``parse_csv`` minus ``LABEL_COLUMN``. Its single receiver
        tensor is exposed under the key ``'csv_row'``.

    Raises:
        KeyError: presumably, if ``parse_csv`` returns a dict without
            ``LABEL_COLUMN`` (the pop below has no default) -- confirm
            against ``parse_csv``.
    """
    rows = tf.placeholder(dtype=tf.string, shape=[None])
    parsed = parse_csv(rows)
    # The label is not supplied at prediction time, so drop it.
    parsed.pop(LABEL_COLUMN)
    receiver_tensors = {'csv_row': rows}
    return tf.estimator.export.ServingInputReceiver(parsed, receiver_tensors)
def example_serving_input_fn():
    """Serving input function that accepts serialized ``tf.train.Example`` protos.

    The serialized protos arrive in a rank-1 string placeholder. They are
    decoded with ``tf.parse_example``, using the parsing spec derived from
    ``INPUT_COLUMNS``.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver`` whose features are the
        parsed tensors. Its single receiver tensor is exposed under the key
        ``'example_proto'``.
    """
    serialized = tf.placeholder(dtype=tf.string, shape=[None])
    parse_spec = tf.feature_column.make_parse_example_spec(INPUT_COLUMNS)
    parsed = tf.parse_example(serialized, parse_spec)
    return tf.estimator.export.ServingInputReceiver(
        parsed, {'example_proto': serialized})
def json_serving_input_fn():
    """Serving input function that accepts one tensor per input column.

    One rank-1 placeholder is created per column in ``INPUT_COLUMNS``. Each is
    keyed by the column's ``name`` and typed with the column's ``dtype``.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver`` that uses the same dict
        of placeholders as both its features and its receiver tensors.
    """
    placeholders = {
        column.name: tf.placeholder(shape=[None], dtype=column.dtype)
        for column in INPUT_COLUMNS
    }
    # Requests feed feature values directly, so no parsing step is needed.
    return tf.estimator.export.ServingInputReceiver(placeholders, placeholders)
# Lookup table from export format name to the serving input function that
# builds the matching serving graph.
SERVING_FUNCTIONS = dict(
    JSON=json_serving_input_fn,
    EXAMPLE=example_serving_input_fn,
    CSV=csv_serving_input_fn,
)
# Run the training job.
# learn_runner pulls its configuration from environment variables (via
# tf.contrib.learn.RunConfig) and uses it to decide whether this process
# executes the Experiment or runs parameter-server code.
_experiment_fn = generate_experiment_fn(
    min_eval_frequency=args.min_eval_frequency,
    eval_delay_secs=args.eval_delay_secs,
    train_steps=args.train_steps,
    eval_steps=args.eval_steps,
    # Export with the serving input fn selected by args.export_format;
    # exports_to_keep=1 retains only a single export.
    export_strategies=[
        saved_model_export_utils.make_export_strategy(
            SERVING_FUNCTIONS[args.export_format],
            exports_to_keep=1,
        )
    ],
)
learn_runner.run(
    _experiment_fn,
    run_config=tf.contrib.learn.RunConfig(model_dir=args.job_dir),
    hparams=hparam.HParams(**args.__dict__),
)
有人可以提供 TensorFlow 中 `tf.contrib.learn.ExportStrategy` 的完整可运行代码示例吗?
官方文档中没有示例。对于这个看似冷门的 TensorFlow API,我在 GitHub 和 Stack Overflow 上也找不到任何示例。
文档:https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ExportStrategy
Google Cloud ML 在这里提供了一个很好的可运行示例: https://github.com/GoogleCloudPlatform/cloudml-samples/tree/master/census/customestimator/trainer
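要让示例真正运行起来,您需要使用他们的完整代码(下面的片段引用了 `parse_csv`、`LABEL_COLUMN`、`INPUT_COLUMNS`、`generate_experiment_fn` 和 `args`,这些名称都需要从完整示例中获得)。下面是如何使用 ExportStrategy 的要点: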
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.utils import (
saved_model_export_utils)
from tensorflow.contrib.training.python.training import hparam
def csv_serving_input_fn():
    """Serving input function that accepts raw CSV rows.

    A rank-1 string placeholder receives the incoming rows. It is parsed with
    ``parse_csv``, and ``LABEL_COLUMN`` is removed from the result because a
    served model only receives features.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver`` whose features are the
        output of ``parse_csv`` minus ``LABEL_COLUMN``. Its single receiver
        tensor is exposed under the key ``'csv_row'``.

    Raises:
        KeyError: presumably, if ``parse_csv`` returns a dict without
            ``LABEL_COLUMN`` (the pop below has no default) -- confirm
            against ``parse_csv``.
    """
    rows = tf.placeholder(dtype=tf.string, shape=[None])
    parsed = parse_csv(rows)
    # The label is not supplied at prediction time, so drop it.
    parsed.pop(LABEL_COLUMN)
    receiver_tensors = {'csv_row': rows}
    return tf.estimator.export.ServingInputReceiver(parsed, receiver_tensors)
def example_serving_input_fn():
    """Serving input function that accepts serialized ``tf.train.Example`` protos.

    The serialized protos arrive in a rank-1 string placeholder. They are
    decoded with ``tf.parse_example``, using the parsing spec derived from
    ``INPUT_COLUMNS``.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver`` whose features are the
        parsed tensors. Its single receiver tensor is exposed under the key
        ``'example_proto'``.
    """
    serialized = tf.placeholder(dtype=tf.string, shape=[None])
    parse_spec = tf.feature_column.make_parse_example_spec(INPUT_COLUMNS)
    parsed = tf.parse_example(serialized, parse_spec)
    return tf.estimator.export.ServingInputReceiver(
        parsed, {'example_proto': serialized})
def json_serving_input_fn():
    """Serving input function that accepts one tensor per input column.

    One rank-1 placeholder is created per column in ``INPUT_COLUMNS``. Each is
    keyed by the column's ``name`` and typed with the column's ``dtype``.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver`` that uses the same dict
        of placeholders as both its features and its receiver tensors.
    """
    placeholders = {
        column.name: tf.placeholder(shape=[None], dtype=column.dtype)
        for column in INPUT_COLUMNS
    }
    # Requests feed feature values directly, so no parsing step is needed.
    return tf.estimator.export.ServingInputReceiver(placeholders, placeholders)
# Lookup table from export format name to the serving input function that
# builds the matching serving graph.
SERVING_FUNCTIONS = dict(
    JSON=json_serving_input_fn,
    EXAMPLE=example_serving_input_fn,
    CSV=csv_serving_input_fn,
)
# Run the training job.
# learn_runner pulls its configuration from environment variables (via
# tf.contrib.learn.RunConfig) and uses it to decide whether this process
# executes the Experiment or runs parameter-server code.
_experiment_fn = generate_experiment_fn(
    min_eval_frequency=args.min_eval_frequency,
    eval_delay_secs=args.eval_delay_secs,
    train_steps=args.train_steps,
    eval_steps=args.eval_steps,
    # Export with the serving input fn selected by args.export_format;
    # exports_to_keep=1 retains only a single export.
    export_strategies=[
        saved_model_export_utils.make_export_strategy(
            SERVING_FUNCTIONS[args.export_format],
            exports_to_keep=1,
        )
    ],
)
learn_runner.run(
    _experiment_fn,
    run_config=tf.contrib.learn.RunConfig(model_dir=args.job_dir),
    hparams=hparam.HParams(**args.__dict__),
)