在pyspark中应用一个udf过滤函数

Apply a udf filtering function in pyspark

我正在尝试对 pyspark 数据帧的特定范围内的值进行过滤和求和。当我使用此代码时它有效:

load_1=[]
for i in range(df.count()):
    start_t = df.select('start_time').where(df.id == i).collect()[0][0]
    try:
        load_1.append(df.where((df.start_time <= start_t) & (start_t <= df.end_time_1)).agg({"pkt_size":"sum"}).collect()[0][0])
    except:
        load_1.append(0)

但是速度很慢。我试图用 udf 加速它并这样做:

def get_load(a, df=df):
    try:
        return df.where((df.start_time <= a) & (a <= df.end_time_1)).agg({"pkt_size":"sum"}).collect()[0][0]
    except:
        return 0

loader = f.udf(get_load)
df.withColumn('load_1', loader(df.start_time).show())

使用此方法时出现此错误:

Could not serialize object: TypeError: can't pickle _thread.RLock objects

关于如何解决这个问题或如何加快它的速度有什么想法吗?我正在尝试做一些类似于我们在 pandas 中应用的函数。数据非常大(将近 40G),我可以使用的资源越多越好。 提前致谢!这是数据示例:

+---+-------------+-----------------+--------+
| id|   start_time|       end_time_1|pkt_size|
+---+-------------+-----------------+--------+
|  1|1000000000000| 1.00000000192E12|    66.0|
|  2|1000000000000| 1.00000000192E12|    66.0|
|  3|1000000006478|1.000000008398E12|    66.0|
|  4|1000000006478|1.000000008398E12|    66.0|
|  5|1000000012956|1.000000014556E12|    58.0|
|  6|1000000012956|1.000000014556E12|    58.0|
|  7|1000000012957|1.000000016156E12|  1518.0|
|  8|1000000012957|1.000000016156E12|  1518.0|
|  9|1000000012957|1.000000017756E12|  1518.0|
| 10|1000000012957|1.000000017756E12|  1518.0|
| 11|1000000012957|1.000000019356E12|  1518.0|
| 12|1000000012957|1.000000019356E12|  1518.0|
| 13|1000000012957|1.000000020956E12|  1518.0|
| 14|1000000012957|1.000000020956E12|  1518.0|
| 15|1000000012957|1.000000022556E12|  1518.0|
| 16|1000000012957|1.000000022556E12|  1518.0|
| 17|1000000012957|1.000000024156E12|  1518.0|
| 18|1000000012957|1.000000024156E12|  1518.0|
| 19|1000000012957|1.000000025756E12|  1518.0|
| 20|1000000012957|1.000000025756E12|  1518.0|
+---+-------------+-----------------+--------+
only showing top 20 rows

目标是对所有行的 pkt_size 求和 start_time 小于每个 id 的 start_time 并且 end_time 大于start_time 的 ID。所以过滤器是基于每行的start_time。

有一种方法可以不用循环或 udf 实现结果:

使用测试数据

+---+----------+----------+--------+                                            
| id|start_time|end_time_1|pkt_size|
+---+----------+----------+--------+
|  1|         2|         5|       4|
|  2|         1|         6|       5|
|  3|         1|         7|       6|
|  4|         5|         6|       7|
|  5|         4|         7|       8|
|  6|         3|         8|       9|
|  7|         6|         7|      10|
+---+----------+----------+--------+

代码

from pyspark.sql import functions as F

data = [(1, 2, 5, 4), 
        (2, 1, 6, 5),
        (3, 1, 7, 6),
        (4, 5, 6, 7),
        (5, 4, 7, 8),
        (6, 3, 8, 9),
        (7, 6, 7, 10)]

df = spark.createDataFrame(data, schema=['id', 'start_time', 'end_time_1', 'pkt_size'])

df.select('id', 'start_time', 'end_time_1') \
    .join(df.selectExpr('start_time as st_1', 'end_time_1 as et_1', 'pkt_size'), \
        F.expr('st_1 < start_time and et_1 > end_time_1')) \
    .groupBy('id') \
    .agg(F.sum('pkt_size')) \
    .show()

打印

+---+-------------+
| id|sum(pkt_size)|
+---+-------------+
|  7|            9|
|  5|            9|
|  1|           11|
|  4|           23|
+---+-------------+

在此示例中,对于 id 1,添加了第 2 行和第 3 行,对于 id 4,添加了第 3、5 和 6 行。

逻辑同题,只是所有id并行计算,而不是一个接一个计算。这种方法需要自连接,因此 Spark 集群应该足够大。