pyspark UDF 中 scipy.optimize.curve_fit 的异常处理

exception handling of scipy.optimize.curve_fit in pyspark UDF

我在 UDF 中有一个 scipy.optimize.curve_fit 调用可能会引发异常。有没有办法处理来自 UDF 外部的异常?

我尝试从 UDF 内部处理异常,但在 collect() 时未捕获异常。

我试过了:

import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import *
from scipy.optimize import curve_fit

def fsigmoid(x, x0, l, k):
    return l / (1.0 + np.exp(-k*(x-x0)))

def curve_fitter(day_0, day_1, day_2, day_3, day_4, day_5, day_6):   
    try:
        # Find sigmoid parameters
        x = list(range(7))
        y = [day_0, day_1, day_2, day_3, day_4, day_5, day_6]
        param_bounds = [[0., 0, -10.], [6., 10., 10.]]
        (x_0, l, k), _ = curve_fit(fsigmoid, x, y, method='dogbox', bounds=(param_bounds[0], param_bounds[1]), maxfev=100)
    except IOError as e:
        (x_0, l, k) = (-1, -1, -1)
    return (float(x_0), float(l), float(k))

# Define UDF
udf_return_schema = StructType([
    StructField("x_0", FloatType(), True),
    StructField("l", FloatType(), True),
    StructField("k", FloatType(), True)
])
udf_curve_fitter = udf(curve_fitter, udf_return_schema)

# Define df and call UDF
data = [(1.6710683580253483, 3.7414496594802005, 5.186749035232343, 8.552623021374485, 0.4000450281109358, 1.7832269020250069, 8.578459510083448)]
df = spark.createDataFrame(data, ['day_' + str(i) for i in range(7)])

df.select([udf_curve_fitter(df['day_0'], df['day_1'], df['day_2'], df['day_3'], df['day_4'], df['day_5'], df['day_6'])]).collect()

我希望 udf_curve_fitter 变为 return (-1, -1, -1),但我得到:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-1-06f383278c7e> in <module>()
     30 data = [(1.6710683580253483, 3.7414496594802005, 5.186749035232343, 8.552623021374485, 0.4000450281109358, 1.7832269020250069, 8.578459510083448)]
     31 df = spark.createDataFrame(data, ['day_' + str(i) for i in range(7)])
---> 32 df.select([udf_curve_fitter(df['day_0'], df['day_1'], df['day_2'], df['day_3'], df['day_4'], df['day_5'], df['day_6'])]).collect()
[...]
Py4JJavaError: An error occurred while calling o130.collectToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1151 in stage 0.0 failed 4 times, most recent failure: Lost task 1151.3 in stage 0.0 (TID 1154, ip-10-0-32-85.ec2.internal, executor 8): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/worker.py", line 177, in main
    process()
  File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/worker.py", line 172, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/serializers.py", line 220, in dump_stream
    self.serializer.dump_stream(self._batched(iterator), stream)
  File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/serializers.py", line 138, in dump_stream
    for obj in iterator:
  File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/serializers.py", line 209, in _batched
    for item in iterator:
  File "<string>", line 1, in <lambda>
  File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/worker.py", line 69, in <lambda>
    return lambda *a: toInternal(f(*a))
  File "<ipython-input-1-06f383278c7e>", line 16, in curve_fitter
  File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/scipy/optimize/minpack.py", line 750, in curve_fit
    raise RuntimeError("Optimal parameters not found: " + res.message)
RuntimeError: Optimal parameters not found: The maximum number of function evaluations is exceeded.

您需要将例外更改为:RuntimeError 而不是 IOError