How to get all combinations of an array column in Spark?
Suppose I have an array column group_ids:
+-------+----------+
|user_id|group_ids |
+-------+----------+
|1      |[5, 8]    |
|3      |[1, 2, 3] |
|2      |[1, 4]    |
+-------+----------+
Schema:
root
 |-- user_id: integer (nullable = false)
 |-- group_ids: array (nullable = false)
 |    |-- element: integer (containsNull = false)
I want to get all pairwise combinations:
+-------+------------------------+
|user_id|group_ids               |
+-------+------------------------+
|1      |[[5, 8]]                |
|3      |[[1, 2], [1, 3], [2, 3]]|
|2      |[[1, 4]]                |
+-------+------------------------+
So far, the simplest solution I have come up with uses a UDF:
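import org.apache.spark.sql.functions.{expr, udf}

// every 2-element combination of the array (the name "permutate" is kept from the original)
spark.udf.register("permutate", udf((xs: Seq[Int]) => xs.combinations(2).toSeq))
dataset.withColumn("group_ids", expr("permutate(group_ids)"))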
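To make clear what the built-in versions have to reproduce, here is a minimal plain-Scala sketch of the per-row computation the UDF performs. It uses the standard library only, with no Spark. The object name CombinationsBaseline and the hard-coded rows are illustrative assumptions; the rows are the sample data from the question.

```scala
// Plain Scala collections (no Spark) mirroring what the "permutate" UDF computes per row.
// CombinationsBaseline is a hypothetical name; rows are the question's sample data.
object CombinationsBaseline {
  val rows: List[(Int, List[Int])] =
    List(1 -> List(5, 8), 3 -> List(1, 2, 3), 2 -> List(1, 4))

  // Same call the UDF body makes: every 2-element combination, in input order.
  def pairs(xs: Seq[Int]): List[List[Int]] =
    xs.combinations(2).map(_.toList).toList

  // One (user_id, pairs) entry per input row, like withColumn over the DataFrame.
  def result: List[(Int, List[List[Int]])] =
    rows.map { case (id, xs) => id -> pairs(xs) }
}
```

Note that combinations(2) keeps each pair in input order, so [3, 1] yields [[3, 1]].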
What I'm looking for is something built on Spark's built-in functions. Is there a way to get the same result without a UDF?
Some higher-order functions can do the trick. This requires Spark >= 2.4.
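import org.apache.spark.sql.functions.expr

// pair every element with every element, then keep only pairs with x[0] < x[1]
val df2 = df.withColumn(
  "group_ids",
  expr("""
    filter(
      transform(
        flatten(
          transform(
            group_ids,
            x -> arrays_zip(
              array_repeat(x, size(group_ids)),
              group_ids
            )
          )
        ),
        x -> array(x['0'], x['group_ids'])
      ),
      x -> x[0] < x[1]
    )
  """)
)

df2.show(false)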
+-------+------------------------+
|user_id|group_ids               |
+-------+------------------------+
|1      |[[5, 8]]                |
|3      |[[1, 2], [1, 3], [2, 3]]|
|2      |[[1, 4]]                |
+-------+------------------------+
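The nested higher-order functions are easier to follow when each one is mapped onto an ordinary collection operation. The following is a minimal plain-Scala sketch (no Spark; the object name HigherOrderMirror is hypothetical) that mirrors the SQL expression step by step. It assumes each SQL function behaves as described in the comments.

```scala
// Plain-Scala mirror (no Spark) of the higher-order-function expression, one step per SQL call.
// HigherOrderMirror is a hypothetical name used only for illustration.
object HigherOrderMirror {
  def pairs(groupIds: List[Int]): List[List[Int]] = {
    val n = groupIds.size
    groupIds
      // transform(group_ids, x -> arrays_zip(array_repeat(x, size(group_ids)), group_ids))
      .map(x => List.fill(n)(x).zip(groupIds))
      // flatten(...): one (x, y) entry for every ordered pair, including (x, x)
      .flatten
      // transform(..., x -> array(x['0'], x['group_ids']))
      .map { case (a, b) => List(a, b) }
      // filter(..., x -> x[0] < x[1]): keep one orientation of each pair, drop (x, x)
      .filter(p => p(0) < p(1))
  }
}
```

The last step compares values rather than positions. As a result, the expression assumes distinct elements and returns each pair in ascending value order. For [3, 1] the sketch yields [[1, 3]] rather than the UDF's [[3, 1]]. For [1, 1, 2] it yields [1, 2] twice and no [1, 1].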
A solution based on explode and joins:
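import org.apache.spark.sql.functions.{col, collect_set, explode, struct}
import spark.implicits._ // for the $"..." column syntax

// one row per (user_id, element)
val exploded = df.select(col("user_id"), explode(col("group_ids")).as("e"))

// to have combinations: the self-join on user_id yields every ordered pair (e1, e2)
val joined1 = exploded.as("t1")
  .join(exploded.as("t2"), Seq("user_id"), "outer")
  .select(col("user_id"), col("t1.e").as("e1"), col("t2.e").as("e2"))

// to filter out redundant combinations: match each (e1, e2) with its mirror (e2, e1), keep e1 < e2
val joined2 = joined1.as("t1")
  .join(joined1.as("t2"), $"t1.user_id" === $"t2.user_id" && $"t1.e1" === $"t2.e2" && $"t1.e2" === $"t2.e1")
  .where("t1.e1 < t2.e1")
  .select("t1.*")

// group into an array of (e1, e2) structs; collect_set does not guarantee element order
val result = joined2.groupBy("user_id")
  .agg(collect_set(struct("e1", "e2")).as("group_ids"))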
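The join-based answer can likewise be traced on plain Scala collections. This is a minimal sketch (no Spark; ExplodeJoinMirror is a hypothetical name). It replaces the second self-join with a direct e1 < e2 filter. That join pairs each (e1, e2) row with its mirror (e2, e1), and the where clause keeps the row only when e1 < e2, so for distinct elements the surviving rows should be the same.

```scala
// Plain-Scala mirror (no Spark) of the explode + self-join approach.
// ExplodeJoinMirror is a hypothetical name; rows are the question's sample data.
object ExplodeJoinMirror {
  val rows: List[(Int, List[Int])] =
    List(1 -> List(5, 8), 3 -> List(1, 2, 3), 2 -> List(1, 4))

  // explode(col("group_ids")): one (user_id, e) record per element
  val exploded: List[(Int, Int)] =
    for ((id, xs) <- rows; e <- xs) yield (id, e)

  // self-join on user_id: every ordered (e1, e2) pair within a user, including (e, e)
  val joined1: List[(Int, Int, Int)] =
    for ((u1, e1) <- exploded; (u2, e2) <- exploded if u1 == u2) yield (u1, e1, e2)

  // simplified stand-in for the second self-join + where("t1.e1 < t2.e1")
  val deduped: List[(Int, Int, Int)] =
    joined1.filter { case (_, e1, e2) => e1 < e2 }

  // groupBy("user_id") + collect_set: a Set per user, with no ordering promise
  val result: Map[Int, Set[(Int, Int)]] =
    deduped.groupBy(_._1).map { case (u, ts) => u -> ts.map(t => (t._2, t._3)).toSet }
}
```

The Set stands in for collect_set, which likewise makes no ordering promise. The Spark version collects (e1, e2) structs rather than the nested arrays shown in the question.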
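You can get the maximum size of the group_ids column. Then build the pairwise sub-arrays from the original array by combining the index range (1 to maxSize) with when expressions, and finally filter the null elements out of the resulting array: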
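import org.apache.spark.sql.functions.{array, element_at, expr, max, size, when}
import spark.implicits._ // for the $"..." column syntax

// largest array length across all rows
val maxSize = df.select(max(size($"group_ids"))).first.getAs[Int](0)

// one candidate column per 1-based index pair; null when the row's array is too short
val newCol = (1 to maxSize).combinations(2)
  .map(c =>
    when(
      size($"group_ids") >= c(1),
      array(element_at($"group_ids", c(0)), element_at($"group_ids", c(1)))
    )
  ).toSeq

df.withColumn("group_ids", array(newCol: _*))
  .withColumn("group_ids", expr("filter(group_ids, x -> x is not null)"))
  .show(false)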
//+-------+------------------------+
//|user_id|group_ids               |
//+-------+------------------------+
//|1      |[[5, 8]]                |
//|3      |[[1, 2], [1, 3], [2, 3]]|
//|2      |[[1, 4]]                |
//+-------+------------------------+
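The index-based answer can also be traced without Spark. The sketch below is plain Scala; IndexMirror is a hypothetical name. It builds the same 1-based index pairs from (1 to maxSize).combinations(2) and uses None where the Spark code produces null. Those None values are flattened away at the end.

```scala
// Plain-Scala mirror (no Spark) of the index-based answer.
// IndexMirror is a hypothetical name; rows are the question's sample data.
object IndexMirror {
  val rows: List[(Int, List[Int])] =
    List(1 -> List(5, 8), 3 -> List(1, 2, 3), 2 -> List(1, 4))

  // max(size($"group_ids"))
  val maxSize: Int = rows.map(_._2.size).max

  // (1 to maxSize).combinations(2): 1-based index pairs, as element_at expects
  val indexPairs: List[IndexedSeq[Int]] = (1 to maxSize).combinations(2).toList

  def pairs(xs: List[Int]): List[List[Int]] =
    indexPairs
      // when(size >= c(1), array(element_at(.., c(0)), element_at(.., c(1)))), None for null
      .map(c => if (xs.size >= c(1)) Some(List(xs(c(0) - 1), xs(c(1) - 1))) else None)
      // filter(group_ids, x -> x is not null)
      .flatten
}
```

This approach selects by position, so it keeps input order ([3, 1] gives [[3, 1]], as with the UDF) and does not require distinct elements. maxSize is computed once from the data, so the generated columns only cover arrays up to that length.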