有效地扩展行数组以分隔列

efficiently expand array of Row to separate columns

我有一个 spark 数据框,其中一个字段是行结构数组。我需要将其扩展到自己的专栏中。问题之一是在数组中,有时会缺少一个字段。

示例如下:

from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql import Row
from pyspark.sql import functions as udf

spark = SparkSession.builder.getOrCreate()

# data
rows = [{'status':'active','member_since':1990,'info':[Row(tag='name',value='John'),Row(tag='age',value='50'),Row(tag='phone',value='1234567')]},
        {'status':'inactive','member_since':2000,'info':[Row(tag='name',value='Tom'),Row(tag='phone',value='1234567')]},
        {'status':'active','member_since':2015,'info':[Row(tag='name',value='Steve'),Row(tag='age',value='28')]}]

# create dataframe
df = spark.createDataFrame(rows)

# transform info to dict
to_dict = udf.UserDefinedFunction(lambda s:dict(s),MapType(StringType(),StringType()))
df = df.withColumn("info_dict",to_dict("info"))

# extract name, NA if not exists
extract_name = udf.UserDefinedFunction(lambda s:s.get("name","NA"))
df = df.withColumn("name",extract_name("info_dict"))

# extract age, NA if not exists
extract_age = udf.UserDefinedFunction(lambda s:s.get("age","NA"))
df = df.withColumn("age",extract_age("info_dict"))

# extract phone, NA if not exists
extract_phone = udf.UserDefinedFunction(lambda s:s.get("phone","NA"))
df = df.withColumn("phone",extract_phone("info_dict"))

df.show()

可以看到'Tom',缺少'age';对于 'Steve',缺少 'phone'。像上面的代码片段一样,我目前的解决方案是首先将数组转换为字典,然后将每个单独的字段解析到它们的列中。结果是这样的:

+--------------------+------------+--------+--------------------+-----+---+-------+
|                info|member_since|  status|           info_dict| name|age|  phone|
+--------------------+------------+--------+--------------------+-----+---+-------+
|[[name, John], [a...|        1990|  active|[name -> John, ph...| John| 50|1234567|
|[[name, Tom], [ph...|        2000|inactive|[name -> Tom, pho...|  Tom| NA|1234567|
|[[name, Steve], [...|        2015|  active|[name -> Steve, a...|Steve| 28|     NA|
+--------------------+------------+--------+--------------------+-----+---+-------+

我真的只想要列 'status'、'member_since'、'name'、'age' 和 'phone'。由于 UDF,此解决方案有效但速度很慢。有没有更快的选择?谢谢

我可以想到 2 种使用 DataFrame 函数执行此操作的方法。我相信第一个应该更快,但代码不那么优雅。第二个更紧凑,但可能慢。

方法一:动态创建地图

此方法的核心是将您的 Row 变成 MapType()。这可以使用 pyspark.sql.functions.create_map() 和一些魔法使用 functools.reduce()operator.add() 来实现。

from operator import add
import pyspark.sql.functions as f

f.create_map(
    *reduce(
        add,
        [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
         for k in range(3)]
    )
)

问题是没有一种方法(AFAIK)可以动态确定 WrappedArray 的长度或 它以一种简单的方式。如果缺少一个值,这将导致错误,因为映射键不能是 null。然而,由于我们知道列表可以包含 1、2、3 个元素,因此我们可以针对每种情况进行测试。

df.withColumn(
    'map',
    f.when(f.size(f.col('info')) == 1, 
        f.create_map(
            *reduce(
                add,
                [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
                 for k in range(1)]
            )
        )
    ).otherwise(
    f.when(f.size(f.col('info')) == 2, 
        f.create_map(
            *reduce(
                add,
                [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
                 for k in range(2)]
            )
        )
    ).otherwise(
    f.when(f.size(f.col('info')) == 3, 
        f.create_map(
            *reduce(
                add,
                [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
                 for k in range(3)]
            )
        )
    )))
).select(
    ['member_since', 'status'] + [f.col("map").getItem(k).alias(k) for k in keys]
).show(truncate=False)

最后一步使用 中描述的方法将 'map' 键转换为列。

这会产生以下输出:

+------------+--------+-----+----+-------+
|member_since|status  |name |age |phone  |
+------------+--------+-----+----+-------+
|1990        |active  |John |50  |1234567|
|2000        |inactive|Tom  |null|1234567|
|2015        |active  |Steve|28  |null   |
+------------+--------+-----+----+-------+

方法 2:使用 explode、groupBy 和 pivot

首先在 'info' 列上使用 pyspark.sql.functions.explode(),然后使用 'tag''value' 列作为 create_map() 的参数:

df.withColumn('id', f.monotonically_increasing_id())\
    .withColumn('exploded', f.explode(f.col('info')))\
    .withColumn(
        'map', 
        f.create_map(*[f.col('exploded')['tag'], f.col('exploded')['value']]).alias('map')
    )\
    .select('id', 'member_since', 'status', 'map')\
    .show(truncate=False)
#+------------+------------+--------+---------------------+
#|id          |member_since|status  |map                  |
#+------------+------------+--------+---------------------+
#|85899345920 |1990        |active  |Map(name -> John)    |
#|85899345920 |1990        |active  |Map(age -> 50)       |
#|85899345920 |1990        |active  |Map(phone -> 1234567)|
#|180388626432|2000        |inactive|Map(name -> Tom)     |
#|180388626432|2000        |inactive|Map(phone -> 1234567)|
#|266287972352|2015        |active  |Map(name -> Steve)   |
#|266287972352|2015        |active  |Map(age -> 28)       |
#+------------+------------+--------+---------------------+

我还使用 pyspark.sql.functions.monotonically_increasing_id() 添加了一个列 'id' 以确保我们可以跟踪哪些行属于同一记录。

现在我们可以展开地图列、groupBy()pivot()。我们可以使用 pyspark.sql.functions.first() 作为 groupBy() 的聚合函数,因为我们知道每个组中只有一个 'value'

df.withColumn('id', f.monotonically_increasing_id())\
    .withColumn('exploded', f.explode(f.col('info')))\
    .withColumn(
        'map', 
        f.create_map(*[f.col('exploded')['tag'], f.col('exploded')['value']]).alias('map')
    )\
    .select('id', 'member_since', 'status', f.explode('map'))\
    .groupBy('id', 'member_since', 'status').pivot('key').agg(f.first('value'))\
    .select('member_since', 'status', 'age', 'name', 'phone')\
    .show()
#+------------+--------+----+-----+-------+
#|member_since|  status| age| name|  phone|
#+------------+--------+----+-----+-------+
#|        1990|  active|  50| John|1234567|
#|        2000|inactive|null|  Tom|1234567|
#|        2015|  active|  28|Steve|   null|
#+------------+--------+----+-----+-------+