通过将数据框中的一行迭代到另一个数据框中的所有行来检查最小值

Check the minimum by iterating one row in a dataframe over all the rows in another dataframe

假设我有以下两个数据帧:

DF1:
+----------+----------+----------+
|     Place|Population|    IndexA|     
+----------+----------+----------+
|         A|       Int|       X_A|
|         B|       Int|       X_B|
|         C|       Int|       X_C|
+----------+----------+----------+

DF2:
+----------+----------+
|      City|    IndexB|     
+----------+----------+
|         D|       X_D|      
|         E|       X_E|   
|         F|       X_F|  
|      ....|      ....|
|        ZZ|      X_ZZ|
+----------+----------+

上面的数据帧通常要大得多。

我想确定从 DF1 到每个 Place 到哪个 City(DF2) 的最短距离。可以根据索引计算距离。因此,对于 DF1 中的每一行,我必须遍历 DF2 中的每一行,并根据索引的计算寻找最短距离。对于距离计算,定义了一个函数:

val distance = udf(
      (indexA: Long, indexB: Long) => {
        h3.instance.h3Distance(indexA, indexB)
      })

我尝试了以下方法:

val output =  DF1.agg(functions.min(distance(col("IndexA"), DF2.col("IndexB"))))

但是这个,代码编译但我得到以下错误:

Exception in thread "main" org.apache.spark.sql.AnalysisException: Resolved attribute(s)
H3Index#220L missing from Places#316,Population#330,IndexAx#338L in operator !Aggregate
[min(if ((isnull(IndexA#338L) OR isnull(IndexB#220L))) null else UDF(knownnotnull(IndexA#338L), knownnotnull(IndexB#220L))) AS min(UDF(IndexA, IndexB))#346].

所以我想我在从 DF1 中取出一行时迭代 DF2 中的每一行时做错了,但我找不到解决方案。

我做错了什么?我的方向对吗?

您收到此错误是因为您使用的索引列仅存在于 DF2 中,而不存在于您尝试执行聚合的 DF1 中。

为了使该字段可访问并确定与所有点的距离,您需要

  1. 交叉连接 DF1Df2Df1 的每个索引匹配 DF2
  2. 的每个索引
  3. 使用您的 udf 确定距离
  4. 在这个加入 udf 的新交叉上找到最小值

这可能看起来像:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, min, udf}

val distance = udf(
      (indexA: Long, indexB: Long) => {
        h3.instance.h3Distance(indexA, indexB)
      })

val resultDF = DF1.crossJoin(DF2)
    .withColumn("distance", distance(col("IndexA"), col("IndexB")))
    //instead of using a groupby then matching the min distance of the aggregation with the initial df. I've chosen to use a window function min to determine the min_distance of each group (determined by Place) and filter by the city with the min distance to each place
    .withColumn("min_distance", min("distance").over(Window.partitionBy("Place")))
    .where(col("distance") === col("min_distance"))
    .drop("min_distance")

这将生成一个数据框,其中的列来自数据框和附加列 distance

NB. 您当前将一个 df 中的每个项目与另一个 df 中的每个项目进行比较的方法是一项昂贵的操作。如果您有机会尽早过滤(例如加入启发式列,即可能指示某个地方可能更靠近城市的其他列),建议这样做。

让我知道这是否适合你。

如果你只有几个城市(少于或大约 1000 个),你可以通过在数组中收集城市然后为每个地方执行距离计算来避免 crossJoinWindow 随机这个收集的数组:

import org.apache.spark.sql.functions.{array_min, col, struct, transform, typedLit, udf}

val citiesIndexes = df2.select("City", "IndexB")
  .collect()
  .map(row => (row.getString(0), row.getLong(1)))

val result = df1.withColumn(
  "City",
  array_min(
    transform(
      typedLit(citiesIndexes),
      x => struct(distance(col("IndexA"), x.getItem("_2")), x.getItem("_1"))
    )
  ).getItem("col2")
)

这段代码适用于 Spark 3 及更高版本。如果您使用的 Spark 版本小于 3.0,则应将 array_min(...).getItem("col2") 部分替换为用户定义的函数。