使用 udf 和 numpy 对 Pyspark 中的列表进行排序
Sort list in Pyspark using udf and numpy
我有一个 PySpark 数据框,其中第二列是列表列表。
以下是我拥有的 PySpark 数据框:
+---+------------------------------+
|A |B |
+---+------------------------------+
|a |[[95.0], [25.0, 25.0], [40.0]]|
|a |[[95.0], [20.0, 80.0]] |
|a |[[95.0], [25.0, 75.0]] |
|b |[[95.0], [25.0, 75.0]] |
|b |[[95.0], [12.0, 88.0]] |
+---+------------------------------+
在此示例中,我尝试展平数组(在第二列中),对数组进行排序并删除随后的 numpy 数组中的最大元素。
以下是我期望的输出:
+---+------------------------------+
|A |B |
+---+------------------------------+
|a |[25.0, 25.0, 40.0] |
|a |[20.0, 80.0] |
|a |[25.0, 75.0] |
|b |[25.0, 75.0] |
|b |[12.0, 88.0] |
+---+------------------------------+
下面是我目前拥有的udf:
def remove_highest(col):
return np.sort( np.asarray([item for sublist in col for item in sublist]) )[:-1]
udf_remove_highest = F.udf( remove_highest , T.ArrayType() )
当我尝试创建此 udf 时出现以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-20-6984c2f41293> in <module>()
2 return np.sort( np.asarray([item for sublist in col for item in sublist]) )[:-1]
3
----> 4 udf_remove_highest = F.udf( remove_highest , T.ArrayType() )
TypeError: __init__() missing 1 required positional argument: 'elementType'
我更喜欢使用 numpy 数组的 udf。我怎样才能达到上述目标?
要使您的代码正常工作,请执行以下操作:
Numpy 数组类型不支持作为 spark 数据帧的数据类型,因此当您返回转换后的数组时,向其添加一个 .tolist() 将其作为接受的 python 列表发送.并在 arraytype
中添加 floattype
def remove_highest(col):
return (np.sort( np.asarray([item for sublist in col for item in sublist]) )[:-1]).tolist()
udf_remove_highest = F.udf( remove_highest , T.ArrayType(T.FloatType()) )
没有 udfs 的最有效方法。使用高阶函数:
这仅适用于 spark 2.4 及更高版本。
正在创建示例数据框:
from pyspark.sql import functions as F
from pyspark.sql.types import *
list=[['a',[[95.0], [25.0, 25.0], [40.0]]],
['a',[[95.0], [20.0, 80.0]]],
['a',[[95.0], [25.0, 75.0]]],
['b',[[95.0], [25.0, 75.0]]],
['b',[[95.0], [12.0, 88.0]]]]
cSchema = StructType([StructField("A", StringType())\
,StructField("B", ArrayType(ArrayType(FloatType())))])
df= spark.createDataFrame(list,schema=cSchema)
过滤表达式,用扁平化和 array_max:
expression="""filter(B, x -> x != C )"""
df1=df.withColumn("B",(F.sort_array(F.flatten("B")))).withColumn("C",F.array_max("B")).withColumn("B", F.expr(expression) )\
.drop("C")
df1.show()
输出:
+---+------------------+
| A| B|
+---+------------------+
| a|[25.0, 25.0, 40.0]|
| a| [20.0, 80.0]|
| a| [25.0, 75.0]|
| b| [25.0, 75.0]|
| b| [12.0, 88.0]|
+---+------------------+
我有一个 PySpark 数据框,其中第二列是列表列表。
以下是我拥有的 PySpark 数据框:
+---+------------------------------+
|A |B |
+---+------------------------------+
|a |[[95.0], [25.0, 25.0], [40.0]]|
|a |[[95.0], [20.0, 80.0]] |
|a |[[95.0], [25.0, 75.0]] |
|b |[[95.0], [25.0, 75.0]] |
|b |[[95.0], [12.0, 88.0]] |
+---+------------------------------+
在此示例中,我尝试展平数组(在第二列中),对数组进行排序并删除随后的 numpy 数组中的最大元素。
以下是我期望的输出:
+---+------------------------------+
|A |B |
+---+------------------------------+
|a |[25.0, 25.0, 40.0] |
|a |[20.0, 80.0] |
|a |[25.0, 75.0] |
|b |[25.0, 75.0] |
|b |[12.0, 88.0] |
+---+------------------------------+
下面是我目前拥有的udf:
def remove_highest(col):
return np.sort( np.asarray([item for sublist in col for item in sublist]) )[:-1]
udf_remove_highest = F.udf( remove_highest , T.ArrayType() )
当我尝试创建此 udf 时出现以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-20-6984c2f41293> in <module>()
2 return np.sort( np.asarray([item for sublist in col for item in sublist]) )[:-1]
3
----> 4 udf_remove_highest = F.udf( remove_highest , T.ArrayType() )
TypeError: __init__() missing 1 required positional argument: 'elementType'
我更喜欢使用 numpy 数组的 udf。我怎样才能达到上述目标?
要使您的代码正常工作,请执行以下操作:
Numpy 数组类型不支持作为 spark 数据帧的数据类型,因此当您返回转换后的数组时,向其添加一个 .tolist() 将其作为接受的 python 列表发送.并在 arraytype
中添加 floattypedef remove_highest(col):
return (np.sort( np.asarray([item for sublist in col for item in sublist]) )[:-1]).tolist()
udf_remove_highest = F.udf( remove_highest , T.ArrayType(T.FloatType()) )
没有 udfs 的最有效方法。使用高阶函数:
这仅适用于 spark 2.4 及更高版本。
正在创建示例数据框:
from pyspark.sql import functions as F
from pyspark.sql.types import *
list=[['a',[[95.0], [25.0, 25.0], [40.0]]],
['a',[[95.0], [20.0, 80.0]]],
['a',[[95.0], [25.0, 75.0]]],
['b',[[95.0], [25.0, 75.0]]],
['b',[[95.0], [12.0, 88.0]]]]
cSchema = StructType([StructField("A", StringType())\
,StructField("B", ArrayType(ArrayType(FloatType())))])
df= spark.createDataFrame(list,schema=cSchema)
过滤表达式,用扁平化和 array_max:
expression="""filter(B, x -> x != C )"""
df1=df.withColumn("B",(F.sort_array(F.flatten("B")))).withColumn("C",F.array_max("B")).withColumn("B", F.expr(expression) )\
.drop("C")
df1.show()
输出:
+---+------------------+
| A| B|
+---+------------------+
| a|[25.0, 25.0, 40.0]|
| a| [20.0, 80.0]|
| a| [25.0, 75.0]|
| b| [25.0, 75.0]|
| b| [12.0, 88.0]|
+---+------------------+