pyspark sql 函数而不是 rdd distinct
pyspark sql functions instead of rdd distinct
我一直在尝试替换数据集中特定列的字符串。要么是 1 要么是 0,'Y' 如果是 1,否则是 0.
我已经通过 lambda 使用数据帧到 rdd 的转换,设法确定了要定位的列,但是需要一段时间才能处理。
为每一列完成一个到 rdd 的切换,然后执行一个 distinct,这需要一段时间!
如果 'Y' 存在于非重复结果集中,则该列被标识为需要转换。
我想知道是否有人可以建议我如何专门使用 pyspark sql 函数来获得相同的结果,而不必为每一列切换?
示例数据的代码如下:
import pyspark.sql.types as typ
import pyspark.sql.functions as func
col_names = [
('ALIVE', typ.StringType()),
('AGE', typ.IntegerType()),
('CAGE', typ.IntegerType()),
('CNT1', typ.IntegerType()),
('CNT2', typ.IntegerType()),
('CNT3', typ.IntegerType()),
('HE', typ.IntegerType()),
('WE', typ.IntegerType()),
('WG', typ.IntegerType()),
('DBP', typ.StringType()),
('DBG', typ.StringType()),
('HT1', typ.StringType()),
('HT2', typ.StringType()),
('PREV', typ.StringType())
]
schema = typ.StructType([typ.StructField(c[0], c[1], False) for c in col_names])
df = spark.createDataFrame([('Y',22,56,4,3,65,180,198,18,'N','Y','N','N','N'),
('N',38,79,3,4,63,155,167,12,'N','N','N','Y','N'),
('Y',39,81,6,6,60,128,152,24,'N','N','N','N','Y')]
,schema=schema)
cols = [(col.name, col.dataType) for col in df.schema]
transform_cols = []
for s in cols:
if s[1] == typ.StringType():
distinct_result = df.select(s[0]).distinct().rdd.map(lambda row: row[0]).collect()
if 'Y' in distinct_result:
transform_cols.append(s[0])
print(transform_cols)
输出为:
['ALIVE', 'DBG', 'HT2', 'PREV']
我设法使用 udf
来完成任务。首先,选择带有 Y
或 N
的列(这里我使用 func.first
以便浏览第一行):
cols_sel = df.select([func.first(col).alias(col) for col in df.columns]).collect()[0].asDict()
cols = [col_name for (col_name, v) in cols_sel.items() if v in ['Y', 'N']]
# return ['HT2', 'ALIVE', 'DBP', 'HT1', 'PREV', 'DBG']
接下来,您可以创建 udf
函数以将 Y
、N
映射到 1
、0
。
def map_input(val):
map_dict = dict(zip(['Y', 'N'], [1, 0]))
return map_dict.get(val)
udf_map_input = func.udf(map_input, returnType=typ.IntegerType())
for col in cols:
df = df.withColumn(col, udf_map_input(col))
df.show()
最后,您可以对列求和。然后我将输出转换为字典并检查哪些列的值大于 0(即包含 Y
)
out = df.select([func.sum(col).alias(col) for col in cols]).collect()
out = out[0]
print([col_name for (col_name, val) in out.asDict().items() if val > 0])
输出
['DBG', 'HT2', 'ALIVE', 'PREV']
我一直在尝试替换数据集中特定列的字符串。要么是 1 要么是 0,'Y' 如果是 1,否则是 0.
我已经通过 lambda 使用数据帧到 rdd 的转换,设法确定了要定位的列,但是需要一段时间才能处理。
为每一列完成一个到 rdd 的切换,然后执行一个 distinct,这需要一段时间!
如果 'Y' 存在于非重复结果集中,则该列被标识为需要转换。
我想知道是否有人可以建议我如何专门使用 pyspark sql 函数来获得相同的结果,而不必为每一列切换?
示例数据的代码如下:
import pyspark.sql.types as typ
import pyspark.sql.functions as func
col_names = [
('ALIVE', typ.StringType()),
('AGE', typ.IntegerType()),
('CAGE', typ.IntegerType()),
('CNT1', typ.IntegerType()),
('CNT2', typ.IntegerType()),
('CNT3', typ.IntegerType()),
('HE', typ.IntegerType()),
('WE', typ.IntegerType()),
('WG', typ.IntegerType()),
('DBP', typ.StringType()),
('DBG', typ.StringType()),
('HT1', typ.StringType()),
('HT2', typ.StringType()),
('PREV', typ.StringType())
]
schema = typ.StructType([typ.StructField(c[0], c[1], False) for c in col_names])
df = spark.createDataFrame([('Y',22,56,4,3,65,180,198,18,'N','Y','N','N','N'),
('N',38,79,3,4,63,155,167,12,'N','N','N','Y','N'),
('Y',39,81,6,6,60,128,152,24,'N','N','N','N','Y')]
,schema=schema)
cols = [(col.name, col.dataType) for col in df.schema]
transform_cols = []
for s in cols:
if s[1] == typ.StringType():
distinct_result = df.select(s[0]).distinct().rdd.map(lambda row: row[0]).collect()
if 'Y' in distinct_result:
transform_cols.append(s[0])
print(transform_cols)
输出为:
['ALIVE', 'DBG', 'HT2', 'PREV']
我设法使用 udf
来完成任务。首先,选择带有 Y
或 N
的列(这里我使用 func.first
以便浏览第一行):
cols_sel = df.select([func.first(col).alias(col) for col in df.columns]).collect()[0].asDict()
cols = [col_name for (col_name, v) in cols_sel.items() if v in ['Y', 'N']]
# return ['HT2', 'ALIVE', 'DBP', 'HT1', 'PREV', 'DBG']
接下来,您可以创建 udf
函数以将 Y
、N
映射到 1
、0
。
def map_input(val):
map_dict = dict(zip(['Y', 'N'], [1, 0]))
return map_dict.get(val)
udf_map_input = func.udf(map_input, returnType=typ.IntegerType())
for col in cols:
df = df.withColumn(col, udf_map_input(col))
df.show()
最后,您可以对列求和。然后我将输出转换为字典并检查哪些列的值大于 0(即包含 Y
)
out = df.select([func.sum(col).alias(col) for col in cols]).collect()
out = out[0]
print([col_name for (col_name, val) in out.asDict().items() if val > 0])
输出
['DBG', 'HT2', 'ALIVE', 'PREV']