获取pyspark中数组列中所有True元素的索引
get index of all True elements in array column in pyspark
我有:
country | sources | infer_from_source
---------------------------------------------------------------------------
null | ["LUX", "CZE","CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"]
"DEU" | ["DEU"] | ["FALSE"]
功能后我想要什么:
country | sources | infer_from_source | inferred_country
------------------------------------------------------------------------------------------------
null | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"] | ["CZE", "FRA"]
"DEU" | ["DEU"] | ["FALSE"] | "DEU"
我需要创建一个函数
如果 country
列为空,根据 infer_from_source
列数组中的布尔值从 sources
数组中提取国家,否则它应该返回 country
值。
我创建了这个函数
from pyspark.sql.types import BooleanType, IntegerType, StringType, FloatType, ArrayType
import pyspark.sql.functions as F
@udf
def determine_entity_country(country: StringType, sources: ArrayType,
infer_from_source: ArrayType) -> ArrayType:
if country:
return country_value
else:
if "TRUE" in infer_from_source:
idx = infer_from_source.index("TRUE")
return sources[idx]
return None
但这会产生 - 基本上 .index("TRUE")
方法 returns 仅匹配其参数的第一个元素的索引。
country | sources | infer_from_source | inferred_country
--------------------------------------------------------------------
null | ["LUX", "CZE", | ["FALSE", "TRUE", |
| "CHN", "FRA"] | "FALSE", "TRUE"] | "CZE"
"DEU" | ["DEU"] | ["FALSE"] | "DEU"
已修复!只是一个列表理解问题
@udf
def determine_entity_country(country: StringType, sources: ArrayType,
infer_from_source: ArrayType) -> ArrayType:
if country:
return country_value
else:
if "TRUE" in infer_from_source:
max_ix = len(infer_from_source)
true_index_array = [x for x in range(0, max_ix) if infer_from_source[x] == "TRUE"]
return [sources[ix] for ix in true_index_array]
return None
只要使用 Spark 内置函数就可以实现相同的功能,就应该避免使用 UDF,尤其是涉及到 Pyspark UDF 时。
这是在数组上使用高阶函数 transform
+ filter
的另一种方法:
import pyspark.sql.functions as F
df1 = df.withColumn(
"inferred_country",
F.when(
F.col("country").isNotNull(),
F.array(F.col("country"))
).otherwise(
F.expr("""filter(
transform(sources, (x, i) -> IF(boolean(infer_from_source[i]), x, null)),
x -> x is not null
)""")
)
)
df1.show()
#+-------+--------------------+--------------------+----------------+
#|country| sources| infer_from_source|inferred_country|
#+-------+--------------------+--------------------+----------------+
#| null|[LUX, CZE, CHN, FRA]|[FALSE, TRUE, FAL...| [CZE, FRA]|
#| DEU| [DEU]| [FALSE]| [DEU]|
#+-------+--------------------+--------------------+----------------+
从 Spark 3+ 开始,您可以在过滤器 lambda 函数中使用索引:
df1 = df.withColumn(
"inferred_country",
F.when(
F.col("country").isNotNull(),
F.array(F.col("country"))
).otherwise(
F.expr("filter(sources, (x, i) -> boolean(infer_from_source[i]))")
)
)
我有:
country | sources | infer_from_source
---------------------------------------------------------------------------
null | ["LUX", "CZE","CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"]
"DEU" | ["DEU"] | ["FALSE"]
功能后我想要什么:
country | sources | infer_from_source | inferred_country
------------------------------------------------------------------------------------------------
null | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"] | ["CZE", "FRA"]
"DEU" | ["DEU"] | ["FALSE"] | "DEU"
我需要创建一个函数
如果 country
列为空,根据 infer_from_source
列数组中的布尔值从 sources
数组中提取国家,否则它应该返回 country
值。
我创建了这个函数
from pyspark.sql.types import BooleanType, IntegerType, StringType, FloatType, ArrayType
import pyspark.sql.functions as F
@udf
def determine_entity_country(country: StringType, sources: ArrayType,
infer_from_source: ArrayType) -> ArrayType:
if country:
return country_value
else:
if "TRUE" in infer_from_source:
idx = infer_from_source.index("TRUE")
return sources[idx]
return None
但这会产生 - 基本上 .index("TRUE")
方法 returns 仅匹配其参数的第一个元素的索引。
country | sources | infer_from_source | inferred_country
--------------------------------------------------------------------
null | ["LUX", "CZE", | ["FALSE", "TRUE", |
| "CHN", "FRA"] | "FALSE", "TRUE"] | "CZE"
"DEU" | ["DEU"] | ["FALSE"] | "DEU"
已修复!只是一个列表理解问题
@udf
def determine_entity_country(country: StringType, sources: ArrayType,
infer_from_source: ArrayType) -> ArrayType:
if country:
return country_value
else:
if "TRUE" in infer_from_source:
max_ix = len(infer_from_source)
true_index_array = [x for x in range(0, max_ix) if infer_from_source[x] == "TRUE"]
return [sources[ix] for ix in true_index_array]
return None
只要使用 Spark 内置函数就可以实现相同的功能,就应该避免使用 UDF,尤其是涉及到 Pyspark UDF 时。
这是在数组上使用高阶函数 transform
+ filter
的另一种方法:
import pyspark.sql.functions as F
df1 = df.withColumn(
"inferred_country",
F.when(
F.col("country").isNotNull(),
F.array(F.col("country"))
).otherwise(
F.expr("""filter(
transform(sources, (x, i) -> IF(boolean(infer_from_source[i]), x, null)),
x -> x is not null
)""")
)
)
df1.show()
#+-------+--------------------+--------------------+----------------+
#|country| sources| infer_from_source|inferred_country|
#+-------+--------------------+--------------------+----------------+
#| null|[LUX, CZE, CHN, FRA]|[FALSE, TRUE, FAL...| [CZE, FRA]|
#| DEU| [DEU]| [FALSE]| [DEU]|
#+-------+--------------------+--------------------+----------------+
从 Spark 3+ 开始,您可以在过滤器 lambda 函数中使用索引:
df1 = df.withColumn(
"inferred_country",
F.when(
F.col("country").isNotNull(),
F.array(F.col("country"))
).otherwise(
F.expr("filter(sources, (x, i) -> boolean(infer_from_source[i]))")
)
)