Pyspark:如何过滤 MapType 列上的数据框? (如 isin() 的风格)

Pyspark: How to filter a Dataframe on a MapType column? (as in the style of isin() )

当我想以 isin() 的样式过滤 MapType 列上的 Dataframe 时,最佳策略是什么?

所以基本上我想获取数据帧的所有行,其中 MapType 列的内容与 MapType-“实例”列表中的条目之一匹配。也可以是该列的连接,但到目前为止我尝试的所有方法都失败了,因为 EqualTo does not support ordering on type map.

除了使用 isin() 或 join() 的直接方法外,我还想到了使用 to_json() 将地图转储到 json 然后过滤 Json 字符串,但这似乎是随机排列键,所以这个字符串比较也不可靠?

有没有我遗漏的简单的东西?您建议如何解决这个问题?

示例 df:

+----+---------------------------------------------------------+
|key |metric                                                   |
+----+---------------------------------------------------------+
|123k|Map(metric1 -> 1.3, metric2 -> 6.3, metric3 -> 7.6)      |
|d23d|Map(metric1 -> 1.5, metric2 -> 2.0, metric3 -> 2.2)      |
|as3d|Map(metric1 -> 2.2, metric2 -> 4.3, metric3 -> 9.0)      |
+----+---------------------------------------------------------+

过滤器(伪代码):

df.where(metric.isin([
 Map(metric1 -> 1.3, metric2 -> 6.3, metric3 -> 7.6),
 Map(metric1 -> 1.5, metric2 -> 2.0, metric3 -> 2.2)
])

期望的输出:

----+---------------------------------------------------------+
|key |metric                                                   |
+----+---------------------------------------------------------+
|123k|Map(metric1 -> 1.3, metric2 -> 6.3, metric3 -> 7.6)      |
|d23d|Map(metric1 -> 1.5, metric2 -> 2.0, metric3 -> 2.2)      |
+----+---------------------------------------------------------+

这不是比较映射相等性的最优雅方法:您可以收集映射键,比较两个映射中每个键的值,并确保所有值都相同。我想最好构造一个过滤器 df,然后进行半连接,而不是使用 isin:

传递它们

样本 df 和过滤器 df:

df.show(truncate=False)
+----+------------------------------------------------+
|key |metric                                          |
+----+------------------------------------------------+
|123k|[metric1 -> 1.3, metric2 -> 6.3, metric3 -> 7.6]|
|d23d|[metric1 -> 1.5, metric2 -> 2.0, metric3 -> 2.2]|
|as3d|[metric1 -> 2.2, metric2 -> 4.3, metric3 -> 9.0]|
+----+------------------------------------------------+

filter_df = df.select('metric').limit(2)
filter_df.show(truncate=False)
+------------------------------------------------+
|metric                                          |
+------------------------------------------------+
|[metric1 -> 1.3, metric2 -> 6.3, metric3 -> 7.6]|
|[metric1 -> 1.5, metric2 -> 2.0, metric3 -> 2.2]|
+------------------------------------------------+

过滤方式:

import pyspark.sql.functions as F

result = df.alias('df').join(
    filter_df.alias('filter_df'),
    F.expr("""
        aggregate(
            transform(
                concat(map_keys(df.metric), map_keys(filter_df.metric)),
                x -> filter_df.metric[x] = df.metric[x]
            ),
            true,
            (acc, x) -> acc and x
        )"""),
     'left_semi'
)

result.show(truncate=False)
+----+------------------------------------------------+
|key |metric                                          |
+----+------------------------------------------------+
|123k|[metric1 -> 1.3, metric2 -> 6.3, metric3 -> 7.6]|
|d23d|[metric1 -> 1.5, metric2 -> 2.0, metric3 -> 2.2]|
+----+------------------------------------------------+

比较 spark 中的 2 个映射列不是那么明显。对于第一个地图中的每个键,您需要检查第二个地图中是否具有相同的值。键也一样。

使用 UDF 可能更简单,因为在 Python 你可以检查 dict 相等性:

from pyspark.sql import functions as F

map_equals = F.udf(lambda x, y: x == y, BooleanType())

# create map1 literal to filter with
map1 = F.create_map(*[
    F.lit(x) for x in chain(*{"metric1": 1.3, "metric2": 6.3, "metric3": 7.6}.items())
])

df1 = df.filter(map_equals("metric", map1))

df1.show(truncate=False)

#+----+------------------------------------------------+
#|key |metric                                          |
#+----+------------------------------------------------+
#|123k|[metric1 -> 1.3, metric2 -> 6.3, metric3 -> 7.6]|
#+----+------------------------------------------------+

另一种方法是将要过滤的地图文字添加为列,并检查 metric 中的每个键是否从该文字地图中获得相同的值。

这是一个使用 transfromarray_min 映射键数组来创建过滤器表达式的示例。 (如果 array_min returns true 这意味着所有值都相等):

filter_map_literal = F.create_map(*[
    F.lit(x) for x in chain(*{"metric1": 1.3, "metric2": 6.3, "metric3": 7.6}.items())
])

df1 = df.withColumn("filter_map", filter_map_literal).filter(
    F.array_min(F.expr("""transform(map_keys(metric),
                           x -> if(filter_map[x] = metric[x], true, false)
                    )""")
                )
).drop("filter_map")