对 pyspark 数据帧中的列值应用阈值并将值转换为二进制 0 或 1

apply threshold on column values in a pysaprk dataframe and convert the values to binary 0 or 1

我有一个 PySpark 数据框

simpleData = [("person0",10, 10), \
    ("person1",1, 1), \
    ("person2",1, 0), \
    ("person3",5, 1), \
  ]
columns= ["persons_name","A", 'B']
exp = spark.createDataFrame(data = simpleData, schema = columns)

exp.printSchema()
exp.show()

看起来像

root
 |-- persons_name: string (nullable = true)
 |-- A: long (nullable = true)
 |-- B: long (nullable = true)
 |-- total: long (nullable = true)

+------------+---+---+
|persons_name|  A|  B|
+------------+---+---+
|     person0| 10| 10|
|     person1|  1|  1|   
|     person2|  1|  0|    
|     person3|  5|  1|    
+------------+---+---+

现在我想将值 2 的阈值应用于 A 列和 B 列的值,这样列中小于阈值的任何值都变为 0,大于阈值的值变为 1。

最终结果应该类似于-

+------------+---+---+
|persons_name|  A|  B|
+------------+---+---+
|     person0|  1|  1|
|     person1|  0|  0|   
|     person2|  0|  0|    
|     person3|  1|  0|    
+------------+---+---+

我怎样才能做到这一点?

threshold = 2
exp.select(
    [(F.col(col) > F.lit(threshold)).cast('int').alias(col) for col in ['A', 'B']]
)