用同一列的平均值填充 Pyspark 数据框列空值

Fill Pyspark dataframe column null values with average value from same column

有了这样的数据框,

rdd_2 = sc.parallelize([(0,10,223,"201601"), (0,10,83,"2016032"),(1,20,None,"201602"),(1,20,3003,"201601"), (1,20,None,"201603"), (2,40, 2321,"201601"), (2,30, 10,"201602"),(2,61, None,"201601")])

df_data = sqlContext.createDataFrame(rdd_2, ["id", "type", "cost", "date"])
df_data.show()

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|null| 201602|
|  1|  20|3003| 201601|
|  1|  20|null| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|null| 201601|
+---+----+----+-------+

我需要用现有值的平均值填充空值,预期结果为

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|1128| 201602|
|  1|  20|3003| 201601|
|  1|  20|1128| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|1128| 201601|
+---+----+----+-------+

其中 1128 是现有值的平均值。我需要为多个专栏执行此操作。

我目前的做法是使用na.fill:

fill_values = {column: df_data.agg({column:"mean"}).flatMap(list).collect()[0] for column in df_data.columns if column not in ['date','id']}
df_data = df_data.na.fill(fill_values)

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|1128| 201602|
|  1|  20|3003| 201601|
|  1|  20|1128| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|1128| 201601|
+---+----+----+-------+

但这很麻烦。有什么想法吗?

好吧,无论如何你必须:

  • 计算统计数据
  • 填空

它几乎限制了你在这里真正可以改进的地方,仍然:

  • flatMap(list).collect()[0]替换为first()[0]或结构解包
  • 用一个动作计算所有统计数据
  • 使用内置的Row方法提取字典

最终结果可能是这样的:

def fill_with_mean(df, exclude=set()): 
    stats = df.agg(*(
        avg(c).alias(c) for c in df.columns if c not in exclude
    ))
    return df.na.fill(stats.first().asDict())

fill_with_mean(df_data, ["id", "date"])

在 Spark 2.2 或更高版本中,您还可以使用 Imputer。参见