从 SageMaker 脚本处理器获取变量

Get Variable from SageMaker Script Processor

我正在使用 SageMaker 进行分布式 TensorFlow 模型训练和服务。我正在尝试从 ScriptProcessor 获取预处理数据集的形状,以便我可以将其提供给 TensorFlow 环境。

script_processor = ScriptProcessor(command=['python3'],
                image_uri=preprocess_img_uri,
                role=role,
                instance_count=1,
                sagemaker_session=sm_session,
                instance_type=preprocess_instance_type)

script_processor.run(code=preprocess_script_uri,
                inputs=[ProcessingInput(
                        source=source_dir + username + '/' + dataset_name,
                        destination='/opt/ml/processing/input')],
                outputs=[
                        ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
                        ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test")
                ],

                arguments = ['--filepath', dataset_name, '--labels', 'labels', '--test_size', '0.2', '--shuffle', 'False', '--lookback', '5',])

preprocessing_job_description = script_processor.jobs[-1].describe()

output_config = preprocessing_job_description["ProcessingOutputConfig"]
for output in output_config["Outputs"]:
    if output["OutputName"] == "train_data":
        preprocessed_training_data = output["S3Output"]["S3Uri"]
    if output["OutputName"] == "test_data":
        preprocessed_test_data = output["S3Output"]["S3Uri"]

我想获取以下数据:

pre_processed_train_data_shape = script_processor.train_data_shape?

我只是不确定如何从 docker 容器中获取值。我在这里查看了文档:https://sagemaker.readthedocs.io/en/stable/api/training/processing.html

有几个选项:

  1. 将一些数据写入位于 /opt/ml/output/message 的文本文件,然后调用 DescribeProcessingJob(使用 Boto3 或 AWS CLI 或 API)并检索 ExitMessage

    aws sagemaker describe-processing-job \
      --processing-job-name foo \
      --output text \
      --query ExitMessage
    
  2. 向您的处理作业添加新输出并向其发送数据

  3. 如果您的 train_data 是 CSV、JSON 或 Parquet,则在 train_data 上使用 S3 Select query,因为它是 [=38] =]

    aws s3api select-object-content \
      --bucket foo \
      --key 'path/to/train_data.csv' \
      --expression "SELECT count(*) FROM s3object" \
      --expression-type 'SQL' \
      --input-serialization '{"CSV": {}}' \
      --output-serialization '{"CSV": {}}' /dev/stdout
    

expression 设置为 select * from s3object limit 1 以获取列