使用 Spark 获取值超过某个阈值的所有列的名称
Using Spark to get names of all columns that have a value over some threshold
背景
我们正在将数据从 Redshift 卸载到 S3,然后将其加载到数据帧中,如下所示:
df = spark.read.csv(path, schema=schema, sep='|')
我们将 PySpark 和 AWS EMR(版本 5.4.0)与 Spark 2.1.0 一起使用。
问题
我有一个 Redshift table 正在以 CSV 格式读入 PySpark。记录采用这种格式:
url,category1,category2,category3,category4
http://example.com,0.6,0.0,0.9,0.3
url 是 VARCHAR,类别 值是介于 0.0 和 1.0 之间的 FLOAT。
我想要做的是生成一个新的 DataFrame,每个类别只有一行,其中原始数据集中的值高于某个阈值 X。例如,如果阈值设置为 0.5,那么我想要我的新数据集如下所示:
url,category
http://example.com,category1
http://example.com,category3
我是 Spark/PySpark 的新手,所以我不确定 how/if 这是否可行,如有任何帮助,我们将不胜感激!
编辑:
想添加我的解决方案(基于 Pushkr 的代码)。我们要加载大量类别,以避免对每个类别进行硬编码 select 我做了以下操作:
parsed_df = None
for column in column_list:
if not parsed_df:
parsed_df = df.select(df.url, when(df[column]>threshold,column).otherwise('').alias('cat'))
else:
parsed_df = parsed_df.union(df.select(df.url, when(df[column]>threshold,column).otherwise('')))
if parsed_df is not None:
parsed_df = parsed_df.filter(col('cat') != '')
其中 column_list 是先前生成的类别列名称列表,threshold 是 select 类别。
再次感谢!
这是我试过的东西 -
data = [('http://example.com',0.6,0.0,0.9,0.3),('http://example1.com',0.6,0.0,0.9,0.3)]
df = spark.createDataFrame(data)\
.toDF('url','category1','category2','category3','category4')
from pyspark.sql.functions import *
df\
.select(df.url,when(df.category1>0.5,'category1').otherwise('').alias('category'))\
.union(\
df.select(df.url,when(df.category2>0.5,'category2').otherwise('')))\
.union(\
df.select(df.url,when(df.category3>0.5,'category3').otherwise('')))\
.union(\
df.select(df.url,when(df.category4>0.5,'category4').otherwise('')))\
.filter(col('category')!= '')\
.show()
输出:
+-------------------+---------+
| url| category|
+-------------------+---------+
| http://example.com|category1|
|http://example1.com|category1|
| http://example.com|category3|
|http://example1.com|category3|
+-------------------+---------+
背景
我们正在将数据从 Redshift 卸载到 S3,然后将其加载到数据帧中,如下所示:
df = spark.read.csv(path, schema=schema, sep='|')
我们将 PySpark 和 AWS EMR(版本 5.4.0)与 Spark 2.1.0 一起使用。
问题
我有一个 Redshift table 正在以 CSV 格式读入 PySpark。记录采用这种格式:
url,category1,category2,category3,category4
http://example.com,0.6,0.0,0.9,0.3
url 是 VARCHAR,类别 值是介于 0.0 和 1.0 之间的 FLOAT。
我想要做的是生成一个新的 DataFrame,每个类别只有一行,其中原始数据集中的值高于某个阈值 X。例如,如果阈值设置为 0.5,那么我想要我的新数据集如下所示:
url,category
http://example.com,category1
http://example.com,category3
我是 Spark/PySpark 的新手,所以我不确定 how/if 这是否可行,如有任何帮助,我们将不胜感激!
编辑:
想添加我的解决方案(基于 Pushkr 的代码)。我们要加载大量类别,以避免对每个类别进行硬编码 select 我做了以下操作:
parsed_df = None
for column in column_list:
if not parsed_df:
parsed_df = df.select(df.url, when(df[column]>threshold,column).otherwise('').alias('cat'))
else:
parsed_df = parsed_df.union(df.select(df.url, when(df[column]>threshold,column).otherwise('')))
if parsed_df is not None:
parsed_df = parsed_df.filter(col('cat') != '')
其中 column_list 是先前生成的类别列名称列表,threshold 是 select 类别。
再次感谢!
这是我试过的东西 -
data = [('http://example.com',0.6,0.0,0.9,0.3),('http://example1.com',0.6,0.0,0.9,0.3)]
df = spark.createDataFrame(data)\
.toDF('url','category1','category2','category3','category4')
from pyspark.sql.functions import *
df\
.select(df.url,when(df.category1>0.5,'category1').otherwise('').alias('category'))\
.union(\
df.select(df.url,when(df.category2>0.5,'category2').otherwise('')))\
.union(\
df.select(df.url,when(df.category3>0.5,'category3').otherwise('')))\
.union(\
df.select(df.url,when(df.category4>0.5,'category4').otherwise('')))\
.filter(col('category')!= '')\
.show()
输出:
+-------------------+---------+
| url| category|
+-------------------+---------+
| http://example.com|category1|
|http://example1.com|category1|
| http://example.com|category3|
|http://example1.com|category3|
+-------------------+---------+