如何找到网格邻居(x,y 作为整数)将它们分组并计算它们在 spark 中的平均值

How to find grid neighbours (x, y as integers) group them and calculate mean of their values in spark

我正在努力寻找一种方法来从如下所示的数据集中计算邻居平均值:

+------+------+---------+
|     X|     Y|  value  |
+------+------+---------+
|     1|     5|   1     |
|     1|     8|   1     |
|     1|     6|   6     |
|     2|     8|   5     |
|     2|     6|   3     |
+------+------+---------+

例如:

(1, 5) 邻居将是 (1,6), (2,6) 所以我需要找到它们所有值的平均值,这里的答案是 (1 + 6 + 3) / 3 = 3.33

(1, 8) 个邻居将是 (2, 8),它们的平均值将是 (1 + 5) / 2 = 3

我希望我的解决方案看起来像这样(我只是在此处将坐标连接为字符串作为键):

+--------------------------+
|  neighbour_values | mean |
+--------------------------+
| (1,5)_(1,6)_(2,6) | 3.33 |
| (1,8)_(2,8)       | 3    |
+--------------------------+

我尝试过使用列连接,但效果似乎并不理想。 我正在考虑的解决方案之一是迭代 throw table 两次,一次用于元素,一次用于其他值,并检查它是否是邻居。不幸的是,我是 spark 的新手,我似乎找不到任何关于如何操作的信息。

非常感谢任何帮助! 谢谢!:))

答案取决于您是否只关心相邻邻居的分组。这种情况可能会导致歧义,例如,有一个大于两个项目的宽度或高度的连续块。因此,下面的方法假设一组连续坐标中的所有项目都被归为一个组,并且每个原始记录都属于一个组。

这种将集合划分为不相交坐标的假设适用于 union-find 算法。

由于 union-find 是递归的,这种方法将原始元素收集到内存中并根据这些值创建 UDF。请注意,这可能会很慢 and/or 大型数据集需要大量内存。

// create example DF
val df = Seq((1, 5, 1), (1, 8, 1), (1, 6, 6), (2, 8, 5), (2, 6, 3)).toDF("x", "y", "value")

// collect all coordinates into in-memory collections
val coordinates = df.select("x", "y").collect().map(r => (r.getInt(0), r.getInt(1)))
val coordSet = coordinates.toSet

type K = (Int, Int)
val directParent:Map[K,Option[K]] = coordinates.map { case (x: Int, y: Int) =>
  val possibleParents = coordSet.intersect(Set((x - 1, y - 1), (x, y - 1), (x - 1, y)))
  val parent = if (possibleParents.isEmpty) None else Some(possibleParents.min)
  ((x, y), parent)
}.toMap

// skip unionFind if only concerned with direct neighbors
def unionFind(key: K, map:Map[K,Option[K]]): K = {
  val mapValue = map.get(key)
  mapValue.map(parentOpt => parentOpt match {
    case None => key
    case Some(parent) => unionFind(parent, map)
  }).getOrElse(key)
}

val canonicalUDF = udf((x: Int, y: Int) => unionFind((x, y), directParent))

// group using the canonical element
// create column "neighbors" based on x, y values in each group
val avgDF = df.groupBy(canonicalUDF($"x", $"y").alias("canonical")).agg(
  concat_ws("_", collect_list(concat(lit("("), $"x", lit(","), $"y", lit(")")))).alias("neighbors"),
  avg($"value")).drop("canonical")

结果:

avgDF.show(10, false)
+-----------------+------------------+
|neighbors        |avg(value)        |
+-----------------+------------------+
|(1,8)_(2,8)      |3.0               |
|(1,5)_(1,6)_(2,6)|3.3333333333333335|
+-----------------+------------------+