Using Sagemaker predictor in a Spark UDF function

I'm trying to run inference against a TensorFlow model deployed on SageMaker from a Python Spark job. I'm running a (Databricks) notebook that contains the following cell:

def call_predict():
        batch_size = 1
        data = [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2]]
        tensor_proto = tf.make_tensor_proto(values=np.asarray(data), shape=[batch_size, len(data[0])], dtype=tf.float32)      
        prediction = predictor.predict(tensor_proto)
        print("Process time: {}".format((time.clock() - start)))
        return prediction

If I just call call_predict() it works fine:

call_predict()

I get the output:

Process time: 65.261396
Out[61]: {'model_spec': {'name': u'generic_model',
  'signature_name': u'serving_default',
  'version': {'value': 1578909324L}},
 'outputs': {u'ages': {'dtype': 1,
   'float_val': [5.680944442749023],
   'tensor_shape': {'dim': [{'size': 1L}]}}}}

But when I try to call it from the Spark context (inside a UDF), I get a serialization error. The code I'm trying to run is:

dataRange = range(1, 10001)
rangeRDD = sc.parallelize(dataRange, 8)
new_data = rangeRDD.map(lambda x : call_predict())
new_data.count()

The error I get is:

---------------------------------------------------------------------------
PicklingError                             Traceback (most recent call last)
<command-2282434> in <module>()
      2 rangeRDD = sc.parallelize(dataRange, 8)
      3 new_data = rangeRDD.map(lambda x : call_predict())
----> 4 new_data.count()
      5 

/databricks/spark/python/pyspark/rdd.pyc in count(self)
   1094         3
   1095         """
-> 1096         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
   1097 
   1098     def stats(self):

/databricks/spark/python/pyspark/rdd.pyc in sum(self)
   1085         6.0
   1086         """
-> 1087         return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
   1088 
   1089     def count(self):

/databricks/spark/python/pyspark/rdd.pyc in fold(self, zeroValue, op)
    956         # zeroValue provided to each partition is unique from the one provided
    957         # to the final reduce call
--> 958         vals = self.mapPartitions(func).collect()
    959         return reduce(op, vals, zeroValue)
    960 

/databricks/spark/python/pyspark/rdd.pyc in collect(self)
    829         # Default path used in OSS Spark / for non-credential passthrough clusters:
    830         with SCCallSiteSync(self.context) as css:
--> 831             sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
    832         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
    833 

/databricks/spark/python/pyspark/rdd.pyc in _jrdd(self)
   2573 
   2574         wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
-> 2575                                       self._jrdd_deserializer, profiler)
   2576         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
   2577                                              self.preservesPartitioning, self.is_barrier)

/databricks/spark/python/pyspark/rdd.pyc in _wrap_function(sc, func, deserializer, serializer, profiler)
   2475     assert serializer, "serializer should not be empty"
   2476     command = (func, profiler, deserializer, serializer)
-> 2477     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
   2478     return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
   2479                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)

/databricks/spark/python/pyspark/rdd.pyc in _prepare_for_python_RDD(sc, command)
   2461     # the serialized command will be compressed by broadcast
   2462     ser = CloudPickleSerializer()
-> 2463     pickled_command = ser.dumps(command)
   2464     if len(pickled_command) > sc._jvm.PythonUtils.getBroadcastThreshold(sc._jsc):  # Default 1M
   2465         # The broadcast will have same life cycle as created PythonRDD

/databricks/spark/python/pyspark/serializers.pyc in dumps(self, obj)
    709                 msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
    710             cloudpickle.print_exec(sys.stderr)
--> 711             raise pickle.PicklingError(msg)
    712 
    713 

PicklingError: Could not serialize object: TypeError: can't pickle _ssl._SSLSocket objects

I'm not sure what this serialization error means - is it complaining that it failed to deserialize the Predictor?

My notebook has a cell, executed before the cell above, with the following imports:

import sagemaker
import boto3
from sagemaker.tensorflow.model import TensorFlowPredictor
import tensorflow as tf
import numpy as np
import time

The predictor is created with the following code:

sagemaker_client = boto3.client('sagemaker', aws_access_key_id=ACCESS_KEY,
                                aws_secret_access_key=SECRET_KEY, region_name='us-east-1')
sagemaker_runtime_client = boto3.client('sagemaker-runtime', aws_access_key_id=ACCESS_KEY,
                                        aws_secret_access_key=SECRET_KEY, region_name='us-east-1')

boto_session = boto3.Session(region_name='us-east-1')
sagemaker_session = sagemaker.Session(boto_session, sagemaker_client=sagemaker_client, sagemaker_runtime_client=sagemaker_runtime_client)

predictor = TensorFlowPredictor('endpoint-poc', sagemaker_session)

The udf function will be executed by multiple Spark tasks in parallel. Those tasks run in completely isolated Python processes and are scheduled onto physically different machines. Hence any data those functions reference must be on the same node. This is the case for everything created within the udf.

Whenever you reference any object outside of the udf from within the function, that data structure needs to be serialized (pickled) and shipped to each executor. Some object state, such as open connections to a socket, cannot be pickled.
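
For illustration only (this snippet is not from the original post), the same limitation can be reproduced outside Spark: an object holding an open socket cannot be pickled, and the exact error message depends on the Python version.

# Standalone sketch: pickling an open socket fails.
import pickle
import socket

s = socket.create_connection(("example.com", 80))
try:
    pickle.dumps(s)  # raises TypeError: socket objects cannot be pickled
except TypeError as e:
    print(e)
finally:
    s.close()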

You need to make sure that connections are opened lazily on each executor; this must happen only on the first function call on that executor. The connection pooling topic is covered in the docs, but only in the Spark Streaming guide (although it applies to normal batch jobs as well).
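
As a rough sketch of that lazy, per-executor pattern (this is not the code from the answer below, and it assumes the same 'endpoint-poc' endpoint, region and imports as the question), the predictor could also be created once per partition with mapPartitions:

def predict_partition(rows):
    # Everything here is created on the executor, so nothing has to be pickled on the driver.
    import boto3
    import sagemaker
    import numpy as np
    import tensorflow as tf
    from sagemaker.tensorflow.model import TensorFlowPredictor

    boto_session = boto3.Session(region_name='us-east-1')
    sagemaker_session = sagemaker.Session(boto_session)
    predictor = TensorFlowPredictor('endpoint-poc', sagemaker_session)

    for _ in rows:
        data = [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2]]
        tensor_proto = tf.make_tensor_proto(values=np.asarray(data),
                                            shape=[1, len(data[0])], dtype=tf.float32)
        yield predictor.predict(tensor_proto)

new_data = rangeRDD.mapPartitions(predict_partition)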

Normally one would use the Singleton pattern for this, but in Python people use the Borg pattern:

class Env:
    # Borg pattern: every instance shares the same state dictionary.
    _shared_state = {
        "sagemaker_client": None,
        "sagemaker_runtime_client": None,
        "boto_session": None,
        "sagemaker_session": None,
        "predictor": None,
    }

    def __init__(self):
        self.__dict__ = self._shared_state
        if not self.predictor:
            # Built only once per Python process, on the first instantiation.
            self.sagemaker_client = boto3.client('sagemaker', aws_access_key_id=ACCESS_KEY, aws_secret_access_key=SECRET_KEY, region_name='us-east-1')
            self.sagemaker_runtime_client = boto3.client('sagemaker-runtime', aws_access_key_id=ACCESS_KEY, aws_secret_access_key=SECRET_KEY, region_name='us-east-1')

            self.boto_session = boto3.Session(region_name='us-east-1')
            self.sagemaker_session = sagemaker.Session(self.boto_session, sagemaker_client=self.sagemaker_client, sagemaker_runtime_client=self.sagemaker_runtime_client)

            self.predictor = TensorFlowPredictor('endpoint-poc', self.sagemaker_session)


#....
def call_predict():
    env = Env()
    batch_size = 1
    data = [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2]]
    tensor_proto = tf.make_tensor_proto(values=np.asarray(data), shape=[batch_size, len(data[0])], dtype=tf.float32)

    start = time.clock()
    prediction = env.predictor.predict(tensor_proto)

    print("Process time: {}".format((time.clock() - start)))
    return prediction

new_data = rangeRDD.map(lambda x : call_predict())

The Env class is defined on the master node. Its _shared_state starts out with empty entries. When an Env object is instantiated for the first time, it shares that state with all other instances of Env on any subsequent invocation of the udf. Within each individual, parallel-running process this happens exactly once. This way the sessions are shared and don't need to be pickled.
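
As a small illustration of the Borg behaviour (an assumption added here, not part of the original answer), two instances created in the same executor process end up sharing the same predictor:

env_a = Env()  # first instantiation in this process: clients, session and predictor are built
env_b = Env()  # shares _shared_state, so nothing is rebuilt
assert env_a.predictor is env_b.predictor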