use spark SQL to create array of maps column based on key matching

I have two tables:

entities

id | i | sources                        | name
----------------------------------------------------
1a | 0 | {"UK/bla": 1, "FR/blu": 2} | "mae"
1a | 1 | {"UK/bla": 1, "IT/bli": 2} | "coulson"

source_mapping

source_name | source_metadata
-----------------------------------------------------------------------------------------
"UK/bla"    | {"source_name": "UK/bla", "description": "this is a description"}
"FR/blu"    | {"source_name": "FR/blu", "description": "ceci est une description"}
"IT/bli"    | {"source_name": "IT/bli", "description": "questa è una descrizione"}

What I'd like to do is add a column to my entities table:

id | i | sources                        | name |  metadata   
---------------------------------------------------------------
1a | 0 | [{"UK/bla": 1}, {"FR/blu": 2}] | ...  | [{"source_name": "UK/bla", "description": "this is a description"}, {"source_name": "FR/blu", "description": "ceci est une description"}]
1a | 1 | [{"UK/bla": 1}, {"IT/bli": 2}] | ...  | [{"source_name": "UK/bla", "description": "this is a description"}, {"source_name": "IT/bli", "description": "questa è una descrizione"}]

I did find a way to do this:

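from pyspark.sql import functions as F

# One row per (entity, source): exploding a map column yields "key" and "value" columns
entities_sources_exploded = (entities.select(F.col("id"),
                                             F.col("i"),
                                             F.explode(F.col("sources")))
                                     .withColumnRenamed("key", "source_name")
                                     .drop("value"))  # the per-source value isn't needed for the lookup

# Attach each source's metadata; the left join keeps sources that have no mapping
entities_sources_exploded_with_metadata = (entities_sources_exploded
                                           .join(source_mapping,
                                                 entities_sources_exploded.source_name == source_mapping.source_name,
                                                 "left"))
# Re-assemble one row per entity, collecting the metadata into an array
entities_with_metadata = (entities_sources_exploded_with_metadata
                          .groupBy(F.col("id"), F.col("i"))
                          .agg(F.collect_list("source_metadata").alias("metadata")))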

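It works, but I have a sneaking suspicion that there is a way to do this without exploding, using higher-order functions (HOFs) in Spark SQL wrapped in .expr(). I'd love to see how someone more fluent than me would solve this.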

I think this should work:

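import pandas as pd
from pyspark.sql import SparkSession

# Reuse the active session if there is one (e.g. in a notebook)
spark = SparkSession.builder.getOrCreate()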

# Setup data
data1 = pd.DataFrame({
    "id": ["1a", "1a"],
    "i": [0, 1],
    "sources": [{"UK/bla": 1, "FR/blu": 2}, {"UK/bla": 1, "IT/bli": 2}],
    "name": ["mae", "coulson"]
})
df1 = spark.createDataFrame(data1)
data2 = pd.DataFrame({
    "source_name": ["UK/bla", "FR/blu", "IT/bli"],
    "source_metadata": [
        {"source_name": "UK/bla", "description": "this is a description"},
        {"source_name": "FR/blu", "description": "ceci est une description"},
        {"source_name": "IT/bli", "description": "questa è una descrizione"}
    ]
})
df2 = spark.createDataFrame(data2)

# Create temp tables and execute SQL
df1.createOrReplaceTempView("df1")  # registerTempTable is deprecated
df2.createOrReplaceTempView("df2")
query = """
    SELECT
        temp.id,
        temp.i,
        COLLECT_LIST(source) AS sources,
        temp.name,
        COLLECT_LIST(source_metadata) AS metadata
    FROM (
        SELECT
            *,
            map(key, value) AS source
        FROM (
            SELECT
                df1.id,
                df1.i,
                df1.name,
                EXPLODE(df1.sources)
            FROM df1
        ) AS df1_exploded
        JOIN df2
        ON df2.source_name = df1_exploded.key
    ) AS temp
    GROUP BY temp.id, temp.i, temp.name
"""
result = spark.sql(query)
result.show(5)