How to use when statement and array_contains in Pyspark to create a new column based on conditions?
I'm trying to use filter, case-when statements, and the array_contains expression to filter and flag columns in my dataset, and I'd like to do so more efficiently than I currently am. I haven't been able to string these 3 elements together successfully, and was hoping someone could advise, since my current approach works but isn't efficient.
Currently:
import pandas as pd
import pyspark.sql.functions as spark_fns
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [["a", [1, 2, 3]], ["b", [1, 2, 8]], ["c", [3, 5, 4]], ["d", [8, 1, 4]]]
df = pd.DataFrame(data, columns=["product", "list_of_values"])
sdf = spark.createDataFrame(df)

# partially flag using array_contains to determine if element is within list_of_values
partially_flagged_sdf = (
    sdf.withColumn("contains_element1", spark_fns.array_contains(sdf.list_of_values, 1))
    .withColumn("contains_element2", spark_fns.array_contains(sdf.list_of_values, 2))
    .withColumn("contains_element3", spark_fns.array_contains(sdf.list_of_values, 3))
    .withColumn("contains_element4", spark_fns.array_contains(sdf.list_of_values, 4))
)
# using when and filtering, add an additional flag if product == a
# and list_of_values contains element 1 or 2
flagged_sdf = partially_flagged_sdf.withColumn(
    "proda_contains_el1",
    spark_fns.when(
        (spark_fns.col("product") == "a")
        & (spark_fns.col("contains_element1") | spark_fns.col("contains_element2")),
        True,
    ).otherwise(False),
)
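As an aside, a chain of `|` conditions like the one above can be built programmatically with `functools.reduce` rather than written out term by term. A minimal plain-Python sketch of the same OR-folding idea (the boolean `flags` list here is a hypothetical stand-in for the `contains_elementN` columns):

```python
from functools import reduce
import operator

# hypothetical per-element flags standing in for the contains_elementN columns
flags = [False, True, False, False]

# fold the flags together with OR, mirroring cond1 | cond2 | ... on Spark Columns
any_flag = reduce(operator.or_, flags)
print(any_flag)  # True
```

The same `reduce(operator.or_, list_of_columns)` pattern works on PySpark `Column` objects, since `|` is overloaded on them.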
You can use arrays_overlap to check for multiple elements at once:
import pyspark.sql.functions as F

df2 = sdf.withColumn(
    'newcol',
    (F.col('product') == 'a') &
    F.arrays_overlap('list_of_values', F.array(F.lit(1), F.lit(2)))
)
df2.show()
+-------+--------------+------+
|product|list_of_values|newcol|
+-------+--------------+------+
| a| [1, 2, 3]| true|
| b| [1, 2, 8]| false|
| c| [3, 5, 4]| false|
| d| [8, 1, 4]| false|
+-------+--------------+------+
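As a mental model, `arrays_overlap` returns true when the two arrays share at least one element. A plain-Python sketch of that semantics, applied to the same rows (this ignores Spark's null-handling rules, so it is an approximation, not the exact built-in):

```python
def arrays_overlap(a, b):
    """Return True if lists a and b share at least one element."""
    return any(x in b for x in a)

# mirrors the DataFrame above: newcol is true only when product == 'a'
# AND list_of_values overlaps [1, 2]
rows = [("a", [1, 2, 3]), ("b", [1, 2, 8]), ("c", [3, 5, 4]), ("d", [8, 1, 4])]
for product, values in rows:
    newcol = (product == "a") and arrays_overlap(values, [1, 2])
    print(product, newcol)  # a True, then b/c/d False
```

Row "d" contains 1, but its flag is still false because the `product == 'a'` condition fails, which matches the table above.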