Filling missing values with mean in PySpark
I am trying to fill NaN values with the mean using PySpark. Below is the code I am using, followed by the error that occurs:
from pyspark.sql.functions import avg

def fill_with_mean(df_1, exclude=set()):
    stats = df_1.agg(*(avg(c).alias(c) for c in df_1.columns if c not in exclude))
    return df_1.na.fill(stats.first().asDict())

res = fill_with_mean(df_1, ["MinTemp", "MaxTemp", "Evaporation", "Sunshine"])
res.show()
Error:
Py4JJavaError Traceback (most recent call last)
<ipython-input-35-42f4d984f022> in <module>()
3 stats = df_1.agg(*(avg(c).alias(c) for c in df_1.columns if c not in exclude))
4 return df_1.na.fill(stats.first().asDict())
----> 5 res = fill_with_mean(df_1, ["MinTemp", "MaxTemp", "Evaporation", "Sunshine"])
6 res.show()
5 frames
/usr/local/lib/python3.7/dist-packages/py4j/protocol.py in get_return_value(answer,
gateway_client, target_id, name)
326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
--> 328 format(target_id, ".", name), value)
329 else:
330 raise Py4JError(
Py4JJavaError: An error occurred while calling o376.fill.
: java.lang.NullPointerException
at org.apache.spark.sql.DataFrameNaFunctions.$anonfun$fillMap(DataFrameNaFunctions.scala:418)
at scala.collection.TraversableLike.$anonfun$map(TraversableLike.scala:286)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.AbstractTraversable.map(Traversable.scala:108)
at org.apache.spark.sql.DataFrameNaFunctions.fillMap(DataFrameNaFunctions.scala:407)
at org.apache.spark.sql.DataFrameNaFunctions.fill(DataFrameNaFunctions.scala:232)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
Could you tell me where I am going wrong? Is there another way to fill missing values with the mean?
Here is what my dataframe looks like:
I would like to see the null values replaced with the mean. Also, Evaporation and Sunshine are not completely empty; they contain other values as well.
The dataset is a CSV file:
from pyspark.sql.functions import *
import pyspark

infer_schema = "true"
first_row_is_header = "true"
delimiter = ","

df_1 = spark.read.format("csv").option("header", "true").load('/content/weatherAUS.csv')
df_1.show()
Source: https://www.kaggle.com/jsphyg/weather-dataset-rattle-package
Based on your input data, I created my dataframe:
from pyspark.sql import functions as F, Window
df = spark.read.csv("./weatherAUS.csv", header=True, inferSchema=True, nullValue="NA")
Then I process the whole dataframe, excluding the columns you mentioned plus the ones that cannot be replaced (Date and Location):
exclude = ["date", "location"] + ["mintemp", "maxtemp", "evaporation", "sunshine"]
df2 = df.select(
*(
F.coalesce(F.col(col), F.avg(col).over(Window.orderBy(F.lit(1)))).alias(col)
if col.lower() not in exclude
else F.col(col)
for col in df.columns
)
)
df2.show(5)
+-------------------+----------+-------+-------+--------+-----------+--------+-----------+-------------+----------+----------+------------+------------+-----------+-----------+-----------+-----------+--------+--------+-------+-------+---------+------------+
| Date| Location|MinTemp|MaxTemp|Rainfall|Evaporation|Sunshine|WindGustDir|WindGustSpeed|WindDir9am|WindDir3pm|WindSpeed9am|WindSpeed3pm|Humidity9am|Humidity3pm|Pressure9am|Pressure3pm|Cloud9am|Cloud3pm|Temp9am|Temp3pm|RainToday|RainTomorrow|
+-------------------+----------+-------+-------+--------+-----------+--------+-----------+-------------+----------+----------+------------+------------+-----------+-----------+-----------+-----------+--------+--------+-------+-------+---------+------------+
|2012-07-02 22:00:00|Townsville| 12.4| 23.3| 0.0| 6.0| 10.8| SSW| 33.0| SE| S| 7.0| 20.0| 34.0| 28.0| 1019.5| 1015.5| 1.0| 2.0| 17.5| 23.0| No| No|
|2012-07-03 22:00:00|Townsville| 9.1| 21.7| 0.0| 5.0| 10.9| SE| 39.0| SSW| SSE| 17.0| 20.0| 26.0| 14.0| 1021.7| 1018.4| 1.0| 0.0| 16.4| 21.2| No| No|
|2012-07-04 22:00:00|Townsville| 8.2| 23.4| 0.0| 5.2| 10.6| SSW| 30.0| SSW| NE| 22.0| 13.0| 34.0| 40.0| 1021.7| 1018.5| 2.0| 2.0| 17.1| 22.3| No| No|
|2012-07-05 22:00:00|Townsville| 10.5| 24.5| 0.0| 6.0| 10.2| E| 39.0| SSW| SE| 11.0| 17.0| 48.0| 31.0| 1021.2| 1017.2| 1.0| 2.0| 17.9| 23.8| No| No|
|2012-07-06 22:00:00|Townsville| 17.7| 24.1| 0.0| 6.8| 0.5| SE| 54.0| SE| ESE| 19.0| 31.0| 69.0| 58.0| 1019.2| 1017.0| 8.0| 7.0| 20.1| 23.2| No| No|
+-------------------+----------+-------+-------+--------+-----------+--------+-----------+-------------+----------+----------+------------+------------+-----------+-----------+-----------+-----------+--------+--------+-------+-------+---------+------------+
only showing top 5 rows
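As for the original error: the traceback points at na.fill, and the most likely cause is that your CSV was read without inferSchema, so every column is a string. avg() then returns null for the text columns (Date, Location, the wind directions, ...), and fill() throws a NullPointerException as soon as the fill map contains a null value. If you would rather keep your fill_with_mean approach, here is a minimal sketch (assuming the frame is read with inferSchema=True and nullValue="NA" as above) that simply drops the null entries before filling:

from pyspark.sql.functions import avg

def fill_with_mean(df, exclude=set()):
    # means of all non-excluded columns; non-numeric columns yield a null mean
    stats = df.agg(*(avg(c).alias(c) for c in df.columns if c not in exclude))
    # drop the null entries, since na.fill() fails on a map that contains nulls
    fill_values = {k: v for k, v in stats.first().asDict().items() if v is not None}
    return df.na.fill(fill_values)

res = fill_with_mean(df, {"MinTemp", "MaxTemp", "Evaporation", "Sunshine"})
res.show()

Only the columns whose mean could actually be computed end up in the fill map; the text columns are left untouched.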
You can use the Imputer estimator:
df = spark.createDataFrame([(1.0, float("nan")),
(2.0, float("nan")),
(float("nan"), 3.0),
(4.0, 4.0),
(5.0, 5.0)], ["a", "b"])
df.show()
+---+---+
| a| b|
+---+---+
|1.0|NaN|
|2.0|NaN|
|NaN|3.0|
|4.0|4.0|
|5.0|5.0|
+---+---+
import pyspark.ml.feature as MF
imputer = MF.Imputer(strategy='mean', inputCols=['a', 'b'], outputCols=['out_a', 'out_b'])
model = imputer.fit(df)
model.transform(df).show()
+---+---+-----+-----+
| a| b|out_a|out_b|
+---+---+-----+-----+
|1.0|NaN| 1.0| 4.0|
|2.0|NaN| 2.0| 4.0|
|NaN|3.0| 3.0| 3.0|
|4.0|4.0| 4.0| 4.0|
|5.0|5.0| 5.0| 5.0|
+---+---+-----+-----+
Using method chaining:
(Imputer()
    .setStrategy('mean')
    .setInputCols(['a', 'b'])
    .setOutputCols(['out_a', 'out_b'])
    .fit(df)
    .transform(df)
    .show())
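To relate this back to the weather data: as far as I understand the docs, Imputer treats null values as missing in addition to NaN, so it should also handle the nulls you get when reading the CSV with nullValue="NA". A rough sketch on a few of the numeric weatherAUS columns (df_weather is assumed to be the CSV read with inferSchema=True and nullValue="NA" as in the answer above; the column names come from the table shown earlier, and the cast to double and the _filled suffix are just illustrative choices):

from pyspark.ml.feature import Imputer
from pyspark.sql import functions as F

# df_weather: weatherAUS.csv read with header=True, inferSchema=True, nullValue="NA"
num_cols = ["Rainfall", "WindGustSpeed", "Humidity9am", "Pressure9am"]

# Imputer expects float/double input columns, so cast defensively
# (columns inferred as integers would otherwise be rejected)
df_cast = df_weather
for c in num_cols:
    df_cast = df_cast.withColumn(c, F.col(c).cast("double"))

imputer = Imputer(
    strategy="mean",
    inputCols=num_cols,
    outputCols=[c + "_filled" for c in num_cols],
)
df_filled = imputer.fit(df_cast).transform(df_cast)
df_filled.select(num_cols + [c + "_filled" for c in num_cols]).show(5)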