如何使参数对 SageMaker Tensorflow Endpoint 可用

How to make parameters available to SageMaker Tensorflow Endpoint

我希望使一些超参数可用于 SageMaker 中的服务端点。训练实例可以使用以下超参数访问输入参数:

estimator = TensorFlow(entry_point='autocat.py',
                       role=role,
                       output_path=params['output_path'],
                       code_location=params['code_location'],
                       train_instance_count=1,
                       train_instance_type='ml.c4.xlarge',
                       training_steps=10000,
                       evaluation_steps=None,
                       hyperparameters=params)

但是,在部署端点时,无法在input_fn(serialized_input, content_type)函数中传入用于控制数据处理的参数。

将参数传递给服务实例的最佳方式是什么?? sagemaker.tensorflow.TensorFlow class 中定义的 source_dir 参数是否复制到服务实例?如果是这样,我可以使用 config.yml 或类似的。

Hyper-parameters 用于训练阶段,让您可以调整(Hyper-Parameters 优化 - HPO)您的模型。一旦您有了经过训练的模型,这些 hyper-parameters 就不需要进行推理了。

当您想将功能传递给服务实例时,您通常会在 invoke-endpoint API 调用的每个请求的正文中执行此操作(例如,请参见此处:https://docs.aws.amazon.com/sagemaker/latest/dg/tf-example1-invoke.html) or the call to the predict wrapper in the SageMaker python SDK (https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/tensorflow). You can see such examples in the sample notebooks (https://github.com/awslabs/amazon-sagemaker-examples/blob/master/advanced_functionality/tensorflow_iris_byom/tensorflow_BYOM_iris.ipynb

是的,一种选择是将您的配置文件添加到 source_dir 并在 input_fn 中加载文件。

另一种选择是使用 serving_input_fn(hyperparameters)。该函数将 TensorFlow 模型转换为 TensorFlow 服务模型。例如:

def serving_input_fn(hyperparameters):

    # gets the input shape from the hyperparameters
    shape = hyperparameters.get('input_shape', [1, 7])

    tensor = tf.placeholder(tf.float32, shape=shape)
    # returns the ServingInputReceiver object.

    return build_raw_serving_input_receiver_fn({INPUT_TENSOR_NAME: tensor})()

啊,我遇到了与您类似的问题,我需要从 S3 下载一些东西以在 input_fn 中用于推理。在我的例子中,它是一本字典。

三个选项:

  1. 使用您的 config.yml 方法,并在任何函数声明之前从入口点文件中下载并导入 s3 文件。这将使 input_fn
  2. 可用
  3. 继续使用超参数方法,下载并导入 serving_input_fn 中的矢量化器,并通过全局变量使其可用,以便 input_fn 可以访问它。
  4. 训练前从 s3 下载文件并直接包含在 source_dir 中。

仅当您不需要在初始训练后单独更改矢量化器时,选项 3 才有效。

无论你做什么,都不要直接在input_fn中下载文件。我犯了那个错误,性能很糟糕,因为每次调用端点都会导致下载 s3 文件。