使用 udf 和 numpy 对 Pyspark 中的列表进行排序

Sort list in Pyspark using udf and numpy

我有一个 PySpark 数据框,其中第二列是列表列表。

以下是我拥有的 PySpark 数据框:

+---+------------------------------+
|A  |B                             |
+---+------------------------------+
|a  |[[95.0], [25.0, 25.0], [40.0]]|
|a  |[[95.0], [20.0, 80.0]]        |
|a  |[[95.0], [25.0, 75.0]]        |
|b  |[[95.0], [25.0, 75.0]]        |
|b  |[[95.0], [12.0, 88.0]]        |
+---+------------------------------+

在此示例中,我尝试展平数组(在第二列中),对数组进行排序并删除随后的 numpy 数组中的最大元素

以下是我期望的输出:

+---+------------------------------+
|A  |B                             |
+---+------------------------------+
|a  |[25.0, 25.0, 40.0]            |
|a  |[20.0, 80.0]                  |
|a  |[25.0, 75.0]                  |
|b  |[25.0, 75.0]                  |
|b  |[12.0, 88.0]                  |
+---+------------------------------+

下面是我目前拥有的udf:

def remove_highest(col):
    return np.sort( np.asarray([item for sublist in col for item in sublist])  )[:-1]

udf_remove_highest = F.udf( remove_highest , T.ArrayType() )

当我尝试创建此 udf 时出现以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-20-6984c2f41293> in <module>()
      2     return np.sort( np.asarray([item for sublist in col for item in sublist])  )[:-1]
      3 
----> 4 udf_remove_highest = F.udf( remove_highest , T.ArrayType() )

TypeError: __init__() missing 1 required positional argument: 'elementType'

我更喜欢使用 numpy 数组的 udf。我怎样才能达到上述目标?

要使您的代码正常工作,请执行以下操作:

Numpy 数组类型不支持作为 spark 数据帧的数据类型,因此当您返回转换后的数组时,向其添加一个 .tolist() 将其作为接受的 python 列表发送.并在 arraytype

中添加 floattype
def remove_highest(col):
    return (np.sort( np.asarray([item for sublist in col for item in sublist])  )[:-1]).tolist()

udf_remove_highest = F.udf( remove_highest , T.ArrayType(T.FloatType()) )

没有 udfs 的最有效方法。使用高阶函数:

这仅适用于 spark 2.4 及更高版本。

正在创建示例数据框:

from pyspark.sql import functions as F
from pyspark.sql.types import *

list=[['a',[[95.0], [25.0, 25.0], [40.0]]],
      ['a',[[95.0], [20.0, 80.0]]],
      ['a',[[95.0], [25.0, 75.0]]],
      ['b',[[95.0], [25.0, 75.0]]],
      ['b',[[95.0], [12.0, 88.0]]]]

cSchema = StructType([StructField("A", StringType())\
                      ,StructField("B", ArrayType(ArrayType(FloatType())))])
df= spark.createDataFrame(list,schema=cSchema)

过滤表达式,用扁平化和 array_max:

expression="""filter(B, x -> x != C )"""
df1=df.withColumn("B",(F.sort_array(F.flatten("B")))).withColumn("C",F.array_max("B")).withColumn("B", F.expr(expression) )\
.drop("C")
df1.show()

输出:

+---+------------------+
|  A|                 B|
+---+------------------+
|  a|[25.0, 25.0, 40.0]|
|  a|      [20.0, 80.0]|
|  a|      [25.0, 75.0]|
|  b|      [25.0, 75.0]|
|  b|      [12.0, 88.0]|
+---+------------------+