SageMaker Managed Spot Training with Object Detection 算法

SageMaker Managed Spot Training with Object Detection algorithm

我正在尝试使用新的 Managed Spot Training 功能从现有模型开始训练对象检测模型,创建 Estimator 时使用的参数如下:

od_model = sagemaker.estimator.Estimator(get_image_uri(sagemaker.Session().boto_region_name, 'object-detection', repo_version="latest"),
                                         Config['role'],
                                         train_instance_count = 1,
                                         train_instance_type = 'ml.p3.16xlarge',
                                         train_volume_size = 50,
                                         train_max_run = (48 * 60 * 60),
                                         train_use_spot_instances = True,
                                         train_max_wait = (72 * 60 * 60),
                                         input_mode = 'File',
                                         checkpoint_s3_uri = Config['train_checkpoint_uri'],
                                         output_path = Config['s3_output_location'],
                                         sagemaker_session = sagemaker.Session()
                                         )

(上面对 Config 的引用是我用于 extract/centralise 一些参数的配置数据结构)

当我 运行 以上时,我得到以下异常:

botocore.exceptions.ClientError: An error occurred (ValidationException) when calling the CreateTrainingJob operation: MaxWaitTimeInSeconds above 3600 is not supported for the given algorithm.

如果我将 train_max_wait 更改为 3600,则会出现此异常:

botocore.exceptions.ClientError: An error occurred (ValidationException) when calling the CreateTrainingJob operation: Invalid MaxWaitTimeInSeconds. It must be present and be greater than or equal to MaxRuntimeInSeconds

然而,将 max_run_time 更改为 3600 或更少对我来说是行不通的,因为我预计这个模型需要几天的时间来训练(大数据集),事实上,一个 epoch 需要的时间超过一个小时。

AWS blog post on Managed Spot TrainingMaxWaitTimeInSeconds 限制为 60 分钟:

For built-in algorithms and AWS Marketplace algorithms that don’t use checkpointing, we’re enforcing a maximum training time of 60 minutes (MaxWaitTimeInSeconds parameter).

早些时候,同一博客 post 说:

Built-in algorithms: computer vision algorithms support checkpointing (Object Detection, Semantic Segmentation, and very soon Image Classification).

所以我不认为是我的算法不支持检查点。事实上,该博客 post 使用对象检测和最多 运行 次 48 小时。所以我不认为这是算法限制。

正如您在上面看到的,我已经为检查点设置了一个 S3 URL。 S3 bucket 确实存在,训练容器可以访问它(这是放置训练数据和模型输出的同一个 bucket,在打开 spot training 之前我访问这些没有问题。

我的 boto 和 sagemaker 库是当前版本:

boto3 (1.9.239)
botocore (1.12.239)
sagemaker (1.42.3)

据我阅读各种文档的了解,我已经正确设置了所有内容。我的用例几乎与上面链接的博客 post 中描述的完全相同,但我使用的是 SageMaker Python SDK 而不是控制台。

我真的很想尝试 Managed Spot Training 以节省一些钱,因为我即将进行很长的培训 运行。但是将超时限制为一个小时并不适用于我的用例。有什么建议吗?

更新: 如果我注释掉 train_use_spot_instancestrain_max_wait 选项以在常规按需实例上进行训练,我的训练作业将成功创建。如果我随后尝试使用控制台克隆作业并在克隆上打开 Spot 实例,我会得到相同的 ValidationException。

我今天又 运行 我的脚本,它工作正常,没有 botocore.exceptions.ClientError 异常。鉴于此问题影响了 Sagemaker 的 Python SDK 和控制台,我怀疑这可能是后端 API 而不是我的客户端代码的问题。

不管怎样,现在都可以了。