在 Window (Scala) 上识别列的重复值

Identifying recurring values a column over a Window (Scala)

我有一个包含两列的数据框:"ID" 和 "Amount",每行代表特定 ID 的交易和交易金额。我的示例使用以下 DF:

val df = sc.parallelize(Seq((1, 120),(1, 120),(2, 40),
  (2, 50),(1, 30),(2, 120))).toDF("ID","Amount")

我想创建一个新列来标识所述金额是否为经常性值,即是否出现在同一 ID 的任何其他交易中。

我找到了一种更普遍的方法,即跨越整个列 "Amount",不考虑 ID,使用以下函数:

def recurring_amounts(df: DataFrame, col: String) : DataFrame = {
  var df_to_arr = df.select(col).rdd.map(r => r(0).asInstanceOf[Double]).collect()
  var arr_to_map = df_to_arr.groupBy(identity).mapValues(_.size)
  var map_to_df = arr_to_map.toSeq.toDF(col, "Count")
  var df_reformat = map_to_df.withColumn("Amount", $"Amount".cast(DoubleType))
  var df_out = df.join(df_reformat, Seq("Amount"))
  return df_new
}

val df_output = recurring_amounts(df, "Amount")

这个returns:

+---+------+-----+
|ID |Amount|Count|
+---+------+-----+
| 1 | 120  |  3  |
| 1 | 120  |  3  |
| 2 |  40  |  1  |
| 2 |  50  |  1  | 
| 1 |  30  |  1  |
| 2 | 120  |  3  |
+---+------+-----+

然后我可以用它来创建我想要的二进制变量来指示该金额是否重复出现(如果 > 1 则为是,否则为否)。

然而,我的问题在这个例子中用值 120 来说明,它对于 ID 1 是重复出现的,但对于 ID 2 不是。因此我想要的输出是:

 +---+------+-----+
|ID |Amount|Count|
+---+------+-----+
| 1 | 120  |  2  |
| 1 | 120  |  2  |
| 2 |  40  |  1  |
| 2 |  50  |  1  | 
| 1 |  30  |  1  |
| 2 | 120  |  1  |
+---+------+-----+

我一直在想办法应用函数 .over(Window.partitionBy("ID") 但不知道该怎么做。任何提示将不胜感激。

如果您在 sql 方面做得很好,您可以为您的 Dataframe 编写 sql 查询。您需要做的第一件事是在 spark 的内存中将您的 Dataframe 注册为 table。之后你可以在 table 上面写 sql。请注意,spark 是 spark 会话变量。

val df = sc.parallelize(Seq((1, 120),(1, 120),(2, 40),(2, 50),(1, 30),(2, 120))).toDF("ID","Amount")
df.registerTempTable("transactions")
spark.sql("select *,count(*) over(partition by ID,Amount) as Count from transactions").show()

如果您有任何问题,请告诉我。