How to find grid neighbours (x, y as integers), group them, and calculate the mean of their values in Spark


|     X|     Y|  value  |
|     1|     5|   1     |
|     1|     8|   1     |
|     1|     6|   6     |
|     2|     8|   5     |
|     2|     6|   3     |


|  neighbour_values | mean |
| (1,5)_(1,6)_(2,6) | 3.33 |
| (1,8)_(2,8)       | 3    |

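The assumption that the coordinates split into disjoint groups of neighbours makes this a good fit for the union-find algorithm.

Because union-find is recursive, this approach collects the original elements into memory and builds a UDF from those values. Note that this may be slow and/or require a lot of memory for large datasets.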
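To make the grouping idea concrete without Spark, here is a minimal plain-Scala sketch of union-find over the sample coordinates. It is not the code from this answer: the names `Coord`, `canonical` and `groupMeans` are made up for illustration, and it assumes that two cells are neighbours when they differ by one of the same three offsets checked further down ((x - 1, y - 1), (x, y - 1), (x - 1, y)).

```scala
import scala.collection.mutable

type Coord = (Int, Int)

// Sample rows from the question: (x, y, value)
val rows: Seq[(Int, Int, Int)] = Seq((1, 5, 1), (1, 8, 1), (1, 6, 6), (2, 8, 5), (2, 6, 3))

// Maps every coordinate to the representative (root) of its group.
def canonical(coords: Set[Coord]): Map[Coord, Coord] = {
  // Every coordinate starts as its own root.
  val parent = mutable.Map[Coord, Coord]()
  coords.foreach(c => parent(c) = c)

  // Iterative find with path compression.
  def find(c: Coord): Coord = {
    var root = c
    while (parent(root) != root) root = parent(root)
    var cur = c
    while (parent(cur) != root) {
      val next = parent(cur)
      parent(cur) = root
      cur = next
    }
    root
  }

  // Merge two groups, keeping the smaller coordinate as the root.
  def union(a: Coord, b: Coord): Unit = {
    val ra = find(a)
    val rb = find(b)
    if (ra != rb) {
      if (Ordering[Coord].lteq(ra, rb)) parent(rb) = ra else parent(ra) = rb
    }
  }

  // Assumed neighbourhood: the same three offsets the answer's code checks.
  for ((x, y) <- coords; n <- Seq((x - 1, y - 1), (x, y - 1), (x - 1, y)) if coords(n))
    union((x, y), n)

  coords.map(c => c -> find(c)).toMap
}

// Groups the rows by root and averages the values in each group.
def groupMeans(data: Seq[(Int, Int, Int)]): Map[Set[Coord], Double] = {
  val root = canonical(data.map(r => (r._1, r._2)).toSet)
  data.groupBy(r => root((r._1, r._2))).values.map { g =>
    (g.map(r => (r._1, r._2)).toSet, g.map(_._3).sum.toDouble / g.size)
  }.toMap
}
```

For the sample rows, `groupMeans(rows)` yields two groups: (1,5), (1,6), (2,6) with a mean of about 3.33, and (1,8), (2,8) with a mean of 3.0.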

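// imports needed outside spark-shell (assumes a SparkSession named `spark`)
import org.apache.spark.sql.functions._
import spark.implicits._

// create example DF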
val df = Seq((1, 5, 1), (1, 8, 1), (1, 6, 6), (2, 8, 5), (2, 6, 3)).toDF("x", "y", "value")

// collect all coordinates into in-memory collections
val coordinates = df.select("x", "y").collect().map(r => (r.getInt(0), r.getInt(1)))
val coordSet = coordinates.toSet

type K = (Int, Int)
// for every coordinate, pick the smallest existing neighbour (if any) as its direct parent
val directParent: Map[K, Option[K]] = coordinates.map { case (x: Int, y: Int) =>
  val possibleParents = coordSet.intersect(Set((x - 1, y - 1), (x, y - 1), (x - 1, y)))
  val parent = if (possibleParents.isEmpty) None else Some(possibleParents.min)
  ((x, y), parent)
}.toMap

// skip unionFind if only concerned with direct neighbors
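// follow parent pointers up to the canonical (root) element; a key without a parent is its own root
def unionFind(key: K, map: Map[K, Option[K]]): K = {
  val mapValue = map.get(key)
  mapValue.map(parentOpt => parentOpt match {
    case None => key
    case Some(parent) => unionFind(parent, map)
  }).getOrElse(key)
}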

val canonicalUDF = udf((x: Int, y: Int) => unionFind((x, y), directParent))

// group using the canonical element
// create column "neighbors" based on x, y values in each group
val avgDF = df.groupBy(canonicalUDF($"x", $"y").alias("canonical")).agg(
  concat_ws("_", collect_list(concat(lit("("), $"x", lit(","), $"y", lit(")")))).alias("neighbors"),
  avg($"value")
).drop("canonical")

avgDF.show(10, false)
|neighbors        |avg(value)        |
|(1,5)_(1,6)_(2,6)|3.3333333333333335|
|(1,8)_(2,8)      |3.0               |
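
One caveat about the `unionFind` lookup above: its recursive call sits inside `Option.map`, so it is not in tail position, and a very long chain of parents could overflow the stack. Below is a sketch of a tail-recursive variant over the same kind of parent map. `findRoot` and the hand-written `parents` map are hypothetical names used only for this illustration; the map mirrors what `directParent` would hold for the sample data.

```scala
import scala.annotation.tailrec

type K = (Int, Int)

// Follow parent pointers until reaching a key that has no parent
// (or a key that is missing from the map altogether).
@tailrec
def findRoot(key: K, parents: Map[K, Option[K]]): K =
  parents.get(key).flatten match {
    case Some(parent) => findRoot(parent, parents)
    case None         => key
  }

// Parent map written out by hand for the sample coordinates.
val parents: Map[K, Option[K]] = Map(
  (1, 5) -> None,
  (1, 8) -> None,
  (1, 6) -> Some((1, 5)),
  (2, 8) -> Some((1, 8)),
  (2, 6) -> Some((1, 5))
)
```

With this map, `findRoot((2, 6), parents)` resolves to `(1, 5)` and `findRoot((2, 8), parents)` resolves to `(1, 8)`, the same canonical elements the recursive version produces. It could be dropped into the UDF in place of `unionFind` if stack depth turns out to be a problem.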