Processing with conditions and feeding result to a new dataframe

I'm trying to find the most efficient way in Spark to do some conditional processing that produces an output df, for example:

if input['col1'] == True and input['col5'] > 20:
    output_df['id'] = input['id']
    output_df['colA'] = input['col3'] * 3

if input['col6'] == True:
    output_df['id'] = input['id']
    output_df['colB'] = input['col4'] * 3

I tried using collect() and iterating over the rows, storing the result rows in a dictionary, but I'm sure there must be a better way.

input_df

+----------+-------+--------------------+--------------------+-----------+---------------+
|id        |col1   |                col3|                col4|col5       |col6           |
+----------+-------+--------------------+--------------------+-----------+---------------+
| 268494441|  false|  28.996884149891063|   94.28749693687607|         30|          false|
| 268534191|  false|    30.0700790355414|  129.64494323730952|         45|          false|
| 268579145|  false|    2.89968841498775|   9.428749693683198|         35|          false|
| 268579191|  false|   7.249221037472766|  23.571874234219017|         30|          false|
| 268571197|  false|   39.84244119591372|  160.95414726813047|         30|          false|
| 268547741|  false|  4.4759630169045845|    16.5906900326544|         20|          false|
| 268547767|  false|  7.4027371799987485|   28.51122204533751|         20|          false|
| 268554748|  false|   38.94343543800605|  226.39507142785138|         40|          false|
| 268559131|  false|  13.771003930703246|   80.53786233547791|         35|          false|
| 268559155|  false|  18.144459509697892|  106.11542843914333|         40|          false|
| 268559440|  false|  11.512269225121075|   51.85172168942658|         45|          false|
| 268566392|  false|  23.304755883656764|    86.3818534005261|         35|          false|
| 268569952|  false|    8.10189596578023|   47.38284770172526|         30|           true|
+----------+-------+--------------------+--------------------+-----------+---------------+

output_df

 |-- id: integer (nullable = true)
 |-- colA: float (nullable = true)
 |-- colB: float (nullable = true)

Try this: compute both columns with when() (rows that fail a condition get null), then keep only the rows where at least one of the two conditions matched:

import pyspark.sql.functions as f

output_df = (
    input_df
    # colA = col3 * 3 where col1 is true and col5 > 20, otherwise null
    .withColumn('colA', f.when(f.col('col1') & (f.col('col5') > 20), f.col('col3') * 3))
    # colB = col4 * 3 where col6 is true, otherwise null
    .withColumn('colB', f.when(f.col('col6'), f.col('col4') * 3))
    # drop rows that matched neither condition
    .where(f.col('colA').isNotNull() | f.col('colB').isNotNull())
    .select('id', 'colA', 'colB')
)
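
If you need the exact schema shown in the question (integer id, float colA/colB), note that col3 * 3 and col4 * 3 produce doubles, so a cast on the final projection is needed. A minimal sketch, assuming id fits in an integer:

output_df = output_df.select(
    # cast to match the expected output schema
    f.col('id').cast('integer'),
    f.col('colA').cast('float'),
    f.col('colB').cast('float'),
)

With the sample input above, only id 268569952 satisfies either condition (its col6 is true), so the result would be a single row with colA null and colB = 47.38284770172526 * 3 ≈ 142.15.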