如何在 Sagemaker 2 中使用序列化器和反序列化器

How to use Serializer and Deserializer in Sagemaker 2

我使用 conda_python3 内核启动了一个 Sagemaker notebook,并遵循 example Notebook for Random Cut Forest.

在撰写本文时,conda_python3 附带的 Sagemaker SDK 是 1.72.0 版本,但我想使用新功能,所以我将笔记本更新为使用最新版本

%%bash
pip install -U sagemaker

我看到它更新了。

print(sagemaker.__version__)

# 2.4.1

从版本 1.x 到 2.x 的更改是 serializer/deserializer classes

以前(在版本 1.72.0 中)我会更新我的预测器以使用正确的 serializer/deserializer,并且可以 运行 推断我的模型

from sagemaker.predictor import csv_serializer, json_deserializer


rcf_inference = rcf.deploy(
    initial_instance_count=1,
    instance_type='ml.m4.xlarge',
)

rcf_inference.content_type = 'text/csv'
rcf_inference.serializer = csv_serializer
rcf_inference.accept = 'application/json'
rcf_inference.deserializer = json_deserializer

results = rcf_inference.predict(some_numpy_array)

(注意这一切都来自 example

我尝试像这样使用 sagemaker 2.4.1 复制它

from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import CSVSerializer

rcf_inference = rcf.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.xlarge',
    serializer=CSVSerializer,
    deserializer=JSONDeserializer
)

results = rcf_inference.predict(some_numpy_array)

然后我收到一个错误

TypeError: serialize() missing 1 required positional argument: 'data'

我知道我使用 serliaizer/deserializer 不正确,但找不到关于如何使用它的好文档

为了使用新的 serializers/deserializers,您需要初始化它们,例如:

from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import CSVSerializer

rcf_inference = rcf.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.xlarge',
    serializer=CSVSerializer(),
    deserializer=JSONDeserializer()
)

在自定义序列化器的情况下,我们可以在 SageMaker 中这样做 2.x:

from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer


class FMSerializer(JSONSerializer):
    def serialize(self, data):
        js = {'instances': []}
        for row in data:
            js['instances'].append({'features': row.tolist()})
        return json.dumps(js)


predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.m4.xlarge",
    serializer=FMSerializer(),
    deserializer=JSONDeserializer()
)