从任务中调用 Java/Scala 函数

Calling Java/Scala function from a task

背景

我最初的问题是 为什么在 map 函数中使用 DecisionTreeModel.predict 会引发异常? 并且与

有关

当我们使用 Scala API a recommended way 时,使用 DecisionTreeModel 获得 RDD[LabeledPoint] 的预测是简单地映射 RDD:

val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}

不幸的是,PySpark 中的类似方法效果不佳:

labelsAndPredictions = testData.map(
    lambda lp: (lp.label, model.predict(lp.features))
labelsAndPredictions.first()

Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transforamtion. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

而不是 official documentation 推荐这样的东西:

predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)

那么这是怎么回事?这里没有广播变量,Scala API定义predict如下:

/**
 * Predict values for a single data point using the model trained.
 *
 * @param features array representing a single data point
 * @return Double prediction from the trained model
 */
def predict(features: Vector): Double = {
  topNode.predict(features)
}

/**
 * Predict values for the given data set using the model trained.
 *
 * @param features RDD representing data points to be predicted
 * @return RDD of predictions for each of the given data points
 */
def predict(features: RDD[Vector]): RDD[Double] = {
  features.map(x => predict(x))
}

所以至少乍一看,从操作或转换调用不是问题,因为预测似乎是本地操作。

说明

经过一番挖掘,我发现问题的根源是 JavaModelWrapper.call method invoked from DecisionTreeModel.predict. It access SparkContext,它需要调用 Java 函数:

callJavaFunc(self._sc, getattr(self._java_model, name), *a)

问题

对于 DecisionTreeModel.predict 的情况,有一个推荐的解决方法,所有必需的代码已经是 Scala 的一部分 API 但是一般来说,有什么优雅的方法来处理这样的问题吗?

目前能想到的解决方案比较重量级:

无法使用默认的 Py4J 网关进行通信。要理解为什么我们必须看一下 PySpark Internals 文档 [1] 中的下图:

由于 Py4J 网关在驱动程序上运行,Python 解释器无法访问它,它通过套接字与 JVM worker 通信(参见示例 PythonRDD / rdd.py)。

理论上可以为每个 worker 创建一个单独的 Py4J 网关,但实际上它不太可能有用。忽略诸如可靠性之类的问题 Py4J 根本就不是为执行数据密集型任务而设计的。

有什么解决方法吗?

  1. 使用Spark SQL Data Sources API包装JVM代码。

    优点:支持,高级别,不需要访问内部 PySpark API

    缺点:相对冗长且没有很好的记录,主要限于输入数据

  2. 使用 Scala UDF 在 DataFrame 上操作。

    优点:易于实现(参见Spark: How to map Python with Scala or Java User Defined Functions?),如果数据已经存储在DataFrame中,则Python和Scala之间没有数据转换,最小访问 Py4J

    缺点:需要访问 Py4J 网关和内部方法,限于 Spark SQL,难以调试,不支持

  3. 以类似于在 MLlib 中完成的方式创建高级 Scala 接口。

    优点:灵活,能够执行任意复杂代码。它可以直接在 RDD 上使用(参见示例 MLlib model wrappers) or with DataFrames (see )。后一种解决方案似乎更加友好,因为所有 ser-de 细节都已由现有 API.

    处理

    缺点:低级别,需要数据转换,与 UDF 一样需要访问 Py4J 和内部 API,不受支持

    可以在

  4. 中找到一些基本示例
  5. 使用外部工作流管理工具在 Python 和 Scala / Java 作业之间切换并将数据传递到 DFS。

    优点:易于实现,对代码本身的改动最小

    缺点:读取/写入数据的成本(Alluxio?)

  6. 使用共享 SQLContext(参见示例 Apache Zeppelin or Livy)使用已注册的临时表在来宾语言之间传递数据。

    优点:非常适合交互式分析

    缺点:对于批处理作业 (Zeppelin) 没有那么多,或者可能需要额外的编排 (Livy)


  1. 约书亚·罗森。 (2014 年 8 月 4 日)PySpark Internals. Retrieved from https://cwiki.apache.org/confluence/display/SPARK/PySpark+Internals