如何使用均值和标准差从 pyspark 中的多个列中删除异常值
How to remove outliers from multiple columns in pyspark using mean and standard deviation
我有以下数据框,我想从定义的列中删除异常值。在下面的示例中,价格和收入。应为每组数据删除异常值。在此示例中,其 'cd' 和 'segment' 列。应根据 5 个标准差删除异常值。
data = [
('a', '1',20,10),
('a', '1',30,16),
('a', '1',50,91),
('a', '1',60,34),
('a', '1',200,23),
('a', '2',33,87),
('a', '2',86,90),
('a','2',89,35),
('a', '2',90,24),
('a', '2',40,97),
('a', '2',1,21),
('b', '1',45,96),
('b', '1',56,99),
('b', '1',89,23),
('b', '1',98,64),
('b', '2',86,42),
('b', '2',45,54),
('b', '2',67,95),
('b','2',86,70),
('b', '2',91,64),
('b', '2',2,53),
('b', '2',4,87)
]
data = (spark.createDataFrame(data, ['cd','segment','price','income']))
我已经使用下面的代码删除异常值,但这只适用于一列。
mean_std = (
data
.groupBy('cd', 'segment')
.agg(
*[f.mean(colName).alias('{}{}'.format('mean_',colName)) for colName in ['price']],
*[f.stddev(colName).alias('{}{}'.format('stddev_',colName)) for colName in ['price']])
)
mean_columns = ['mean_price']
std_columns = ['stddev_price']
upper = mean_std
for col_1 in mean_columns:
for col_2 in std_columns:
if col_1 != col_2:
name = col_1 + '_upper_limit'
upper = upper.withColumn(name, f.col(col_1) + f.col(col_2)*5)
lower = upper
for col_1 in mean_columns:
for col_2 in std_columns:
if col_1 != col_2:
name = col_1 + '_lower_limit'
lower = lower.withColumn(name, f.col(col_1) - f.col(col_2)*5)
outliers = (data.join(lower,
how = 'left',
on = ['cd', 'segment'])
.withColumn('is_outlier_price', f.when((f.col('price')>f.col('mean_price_upper_limit')) |
(f.col('price')<f.col('mean_price_lower_limit')),1)
.otherwise(None))
)
我的最终输出应该为每个变量有一列,说明它是 1 = 删除还是 0 = 保留。
非常感谢对此的任何帮助。
您可以使用 F.when 的列表理解。
您的问题的一个非常简单的示例:
import pyspark.sql.functions as F
tst1= sqlContext.createDataFrame([(1,2,3,4,1,10),(1,3,5,7,2,11),(9,9,10,6,2,9),(2,4,90,9,1,2),(2,10,3,4,1,7),(3,5,11,5,7,8),(10,9,12,6,7,9),(3,6,99,8,1,9)],schema=['val1','val1_low_lim','val1_upper_lim','val2','val2_low_lim','val2_upper_lim'])
tst_res = tst1.select(tst1.columns+[(F.when((F.col(coln)<F.col(coln+'_upper_lim'))&(F.col(coln)>F.col(coln+'_low_lim')),1).otherwise(0)).alias(coln+'_valid') for coln in tst1.columns if "_lim" not in coln ])
结果:
tst_res.show()
+----+------------+--------------+----+------------+--------------+----------+----------+
|val1|val1_low_lim|val1_upper_lim|val2|val2_low_lim|val2_upper_lim|val1_valid|val2_valid|
+----+------------+--------------+----+------------+--------------+----------+----------+
| 1| 2| 3| 4| 1| 10| 0| 1|
| 1| 3| 5| 7| 2| 11| 0| 1|
| 9| 9| 10| 6| 2| 9| 0| 1|
| 2| 4| 90| 9| 1| 2| 0| 0|
| 2| 10| 3| 4| 1| 7| 0| 1|
| 3| 5| 11| 5| 7| 8| 0| 0|
| 10| 9| 12| 6| 7| 9| 1| 0|
| 3| 6| 99| 8| 1| 9| 0| 1|
+----+------------+--------------+----+------------+--------------+----------+----------+
您的代码几乎可以 100% 正常工作。您所要做的就是用列名数组替换单个固定列名,然后循环遍历该数组:
numeric_cols = ['price', 'income']
mean_std = \
data \
.groupBy('cd', 'segment') \
.agg( \
*[F.mean(colName).alias('mean_{}'.format(colName)) for colName in numeric_cols],\
*[F.stddev(colName).alias('stddev_{}'.format(colName)) for colName in numeric_cols])
mean_std
现在是一个数据框,每个 numeric_cols
.
的元素有两列(mean_...
和 stddev_...
)
下一步我们计算numeric_cols
的每个元素的下限和上限:
mean_std_min_max = mean_std
for colName in numeric_cols:
meanCol = 'mean_{}'.format(colName)
stddevCol = 'stddev_{}'.format(colName)
minCol = 'min_{}'.format(colName)
maxCol = 'max_{}'.format(colName)
mean_std_min_max = mean_std_min_max.withColumn(minCol, F.col(meanCol) - 5 * F.col(stddevCol))
mean_std_min_max = mean_std_min_max.withColumn(maxCol, F.col(meanCol) + 5 * F.col(stddevCol))
mean_std_min_max
现在包含 numeric_cols
.
的每个元素的两个附加列 min_...
和 max...
最后是连接,然后像以前一样计算 is_outliers_...
列:
outliers = data.join(mean_std_min_max, how = 'left', on = ['cd', 'segment'])
for colName in numeric_cols:
isOutlierCol = 'is_outlier_{}'.format(colName)
minCol = 'min_{}'.format(colName)
maxCol = 'max_{}'.format(colName)
meanCol = 'mean_{}'.format(colName)
stddevCol = 'stddev_{}'.format(colName)
outliers = outliers.withColumn(isOutlierCol, F.when((F.col(colName) > F.col(maxCol)) | (F.col(colName) < F.col(minCol)), 1).otherwise(0))
outliers = outliers.drop(minCol,maxCol, meanCol, stddevCol)
循环的最后一行只是清理并删除中间列。将其注释掉可能会有所帮助。
最后的结果是:
+---+-------+-----+------+----------------+-----------------+
| cd|segment|price|income|is_outlier_price|is_outlier_income|
+---+-------+-----+------+----------------+-----------------+
| b| 2| 86| 42| 0| 0|
| b| 2| 45| 54| 0| 0|
| b| 2| 67| 95| 0| 0|
| b| 2| 86| 70| 0| 0|
| b| 2| 91| 64| 0| 0|
+---+-------+-----+------+----------------+-----------------+
only showing top 5 rows
我有以下数据框,我想从定义的列中删除异常值。在下面的示例中,价格和收入。应为每组数据删除异常值。在此示例中,其 'cd' 和 'segment' 列。应根据 5 个标准差删除异常值。
data = [
('a', '1',20,10),
('a', '1',30,16),
('a', '1',50,91),
('a', '1',60,34),
('a', '1',200,23),
('a', '2',33,87),
('a', '2',86,90),
('a','2',89,35),
('a', '2',90,24),
('a', '2',40,97),
('a', '2',1,21),
('b', '1',45,96),
('b', '1',56,99),
('b', '1',89,23),
('b', '1',98,64),
('b', '2',86,42),
('b', '2',45,54),
('b', '2',67,95),
('b','2',86,70),
('b', '2',91,64),
('b', '2',2,53),
('b', '2',4,87)
]
data = (spark.createDataFrame(data, ['cd','segment','price','income']))
我已经使用下面的代码删除异常值,但这只适用于一列。
mean_std = (
data
.groupBy('cd', 'segment')
.agg(
*[f.mean(colName).alias('{}{}'.format('mean_',colName)) for colName in ['price']],
*[f.stddev(colName).alias('{}{}'.format('stddev_',colName)) for colName in ['price']])
)
mean_columns = ['mean_price']
std_columns = ['stddev_price']
upper = mean_std
for col_1 in mean_columns:
for col_2 in std_columns:
if col_1 != col_2:
name = col_1 + '_upper_limit'
upper = upper.withColumn(name, f.col(col_1) + f.col(col_2)*5)
lower = upper
for col_1 in mean_columns:
for col_2 in std_columns:
if col_1 != col_2:
name = col_1 + '_lower_limit'
lower = lower.withColumn(name, f.col(col_1) - f.col(col_2)*5)
outliers = (data.join(lower,
how = 'left',
on = ['cd', 'segment'])
.withColumn('is_outlier_price', f.when((f.col('price')>f.col('mean_price_upper_limit')) |
(f.col('price')<f.col('mean_price_lower_limit')),1)
.otherwise(None))
)
我的最终输出应该为每个变量有一列,说明它是 1 = 删除还是 0 = 保留。
非常感谢对此的任何帮助。
您可以使用 F.when 的列表理解。 您的问题的一个非常简单的示例:
import pyspark.sql.functions as F
tst1= sqlContext.createDataFrame([(1,2,3,4,1,10),(1,3,5,7,2,11),(9,9,10,6,2,9),(2,4,90,9,1,2),(2,10,3,4,1,7),(3,5,11,5,7,8),(10,9,12,6,7,9),(3,6,99,8,1,9)],schema=['val1','val1_low_lim','val1_upper_lim','val2','val2_low_lim','val2_upper_lim'])
tst_res = tst1.select(tst1.columns+[(F.when((F.col(coln)<F.col(coln+'_upper_lim'))&(F.col(coln)>F.col(coln+'_low_lim')),1).otherwise(0)).alias(coln+'_valid') for coln in tst1.columns if "_lim" not in coln ])
结果:
tst_res.show()
+----+------------+--------------+----+------------+--------------+----------+----------+
|val1|val1_low_lim|val1_upper_lim|val2|val2_low_lim|val2_upper_lim|val1_valid|val2_valid|
+----+------------+--------------+----+------------+--------------+----------+----------+
| 1| 2| 3| 4| 1| 10| 0| 1|
| 1| 3| 5| 7| 2| 11| 0| 1|
| 9| 9| 10| 6| 2| 9| 0| 1|
| 2| 4| 90| 9| 1| 2| 0| 0|
| 2| 10| 3| 4| 1| 7| 0| 1|
| 3| 5| 11| 5| 7| 8| 0| 0|
| 10| 9| 12| 6| 7| 9| 1| 0|
| 3| 6| 99| 8| 1| 9| 0| 1|
+----+------------+--------------+----+------------+--------------+----------+----------+
您的代码几乎可以 100% 正常工作。您所要做的就是用列名数组替换单个固定列名,然后循环遍历该数组:
numeric_cols = ['price', 'income']
mean_std = \
data \
.groupBy('cd', 'segment') \
.agg( \
*[F.mean(colName).alias('mean_{}'.format(colName)) for colName in numeric_cols],\
*[F.stddev(colName).alias('stddev_{}'.format(colName)) for colName in numeric_cols])
mean_std
现在是一个数据框,每个 numeric_cols
.
mean_...
和 stddev_...
)
下一步我们计算numeric_cols
的每个元素的下限和上限:
mean_std_min_max = mean_std
for colName in numeric_cols:
meanCol = 'mean_{}'.format(colName)
stddevCol = 'stddev_{}'.format(colName)
minCol = 'min_{}'.format(colName)
maxCol = 'max_{}'.format(colName)
mean_std_min_max = mean_std_min_max.withColumn(minCol, F.col(meanCol) - 5 * F.col(stddevCol))
mean_std_min_max = mean_std_min_max.withColumn(maxCol, F.col(meanCol) + 5 * F.col(stddevCol))
mean_std_min_max
现在包含 numeric_cols
.
min_...
和 max...
最后是连接,然后像以前一样计算 is_outliers_...
列:
outliers = data.join(mean_std_min_max, how = 'left', on = ['cd', 'segment'])
for colName in numeric_cols:
isOutlierCol = 'is_outlier_{}'.format(colName)
minCol = 'min_{}'.format(colName)
maxCol = 'max_{}'.format(colName)
meanCol = 'mean_{}'.format(colName)
stddevCol = 'stddev_{}'.format(colName)
outliers = outliers.withColumn(isOutlierCol, F.when((F.col(colName) > F.col(maxCol)) | (F.col(colName) < F.col(minCol)), 1).otherwise(0))
outliers = outliers.drop(minCol,maxCol, meanCol, stddevCol)
循环的最后一行只是清理并删除中间列。将其注释掉可能会有所帮助。
最后的结果是:
+---+-------+-----+------+----------------+-----------------+
| cd|segment|price|income|is_outlier_price|is_outlier_income|
+---+-------+-----+------+----------------+-----------------+
| b| 2| 86| 42| 0| 0|
| b| 2| 45| 54| 0| 0|
| b| 2| 67| 95| 0| 0|
| b| 2| 86| 70| 0| 0|
| b| 2| 91| 64| 0| 0|
+---+-------+-----+------+----------------+-----------------+
only showing top 5 rows