在 Scala 中应用条件修剪均值

Applying conditional trimmed mean in scala

我正在尝试为 Scala 中的每个组实现 80% 的修剪均值,以消除异常值。但这只有在该组中的记录数至少超过 10 条时才适用。

例如,

val sales = Seq(
  ("Warsaw", 2016, 100),
  ("Warsaw", 2017, 200),
  ("Boston", 2015, 50),
  ("Boston", 2016, 150),
  ("Toronto", 2017, 50)
).toDF("city", "year", "amount")

所以在这个数据集中,如果我正在做一个分组,

val groupByCityAndYear = sales
  .groupBy("city", "year").count() 
  .agg(avg($"amount").as("avg_amount"))

所以在这种情况下,如果计数超过 10,则应该去除异常值(可能会修剪 80% 均值),否则直接 avg($"amount")。我怎样才能做到这一点?

这是我得到的修剪均值的正确解释,以解释这种情况,

考虑什么是修剪均值:在典型情况下,您首先按递增顺序对数据进行排序。然后你从底部数到修剪百分比并丢弃这些值。例如,10% 的截尾平均值很常见;在这种情况下,您从最低值开始计数,直到您传递了集合中所有数据的 10%。低于该标记的值被搁置。同样,您从最高值开始倒数,直到超过您的修整百分比,然后将所有大于该值的值放在一边。你现在只剩下中间的 80%。你取那个的平均值,那就是你的 10% 修剪平均值

这可以用 window 函数来完成,但是会很昂贵:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy("city", "year").orderBy("amount")

sales
  .withColumn("rn", row_number().over(w))
  .withColumn("count", count("*").over(w))
  .groupBy("city", "year")
  .agg(avg(when(
    ($"count" < 10) or ($"rn" between($"count" * 0.1, $"count" * 0.9)), 
    $"amount"
  )) as "avg_amount")