获取pyspark中数组列中所有True元素的索引

get index of all True elements in array column in pyspark

我有:

country | sources                     |  infer_from_source   
---------------------------------------------------------------------------
null    | ["LUX", "CZE","CHN", "FRA"] |  ["FALSE", "TRUE", "FALSE", "TRUE"]      
"DEU"   | ["DEU"]                     |  ["FALSE"]          

功能后我想要什么:

country | sources                      |  infer_from_source                   | inferred_country
------------------------------------------------------------------------------------------------
null    | ["LUX", "CZE", "CHN", "FRA"] |  ["FALSE", "TRUE", "FALSE", "TRUE"]  | ["CZE", "FRA"]
"DEU"   | ["DEU"]                      |  ["FALSE"]                           | "DEU"

我需要创建一个函数

如果 country 列为空,根据 infer_from_source 列数组中的布尔值从 sources 数组中提取国家,否则它应该返回 country值。

我创建了这个函数

from pyspark.sql.types import BooleanType, IntegerType, StringType, FloatType, ArrayType
import pyspark.sql.functions as F


@udf
def determine_entity_country(country: StringType, sources: ArrayType, 
                             infer_from_source: ArrayType) -> ArrayType:
    if country:
        return country_value
    else:
       if "TRUE" in infer_from_source:
           idx = infer_from_source.index("TRUE")
           return sources[idx]
  return None

但这会产生 - 基本上 .index("TRUE") 方法 returns 仅匹配其参数的第一个元素的索引。

country | sources        |  infer_from_source   | inferred_country
--------------------------------------------------------------------
null    | ["LUX", "CZE", |  ["FALSE", "TRUE",   | 
        |  "CHN", "FRA"] |   "FALSE", "TRUE"]   | "CZE"
"DEU"   | ["DEU"]        |  ["FALSE"]           | "DEU"

已修复!只是一个列表理解问题

@udf
def determine_entity_country(country: StringType, sources: ArrayType, 
                             infer_from_source: ArrayType) -> ArrayType:
    if country:
        return country_value
    else:
       if "TRUE" in infer_from_source:
            max_ix = len(infer_from_source)
            true_index_array = [x for x in range(0, max_ix) if infer_from_source[x] == "TRUE"]
            return [sources[ix] for ix in true_index_array] 
  return None

只要使用 Spark 内置函数就可以实现相同的功能,就应该避免使用 UDF,尤其是涉及到 Pyspark UDF 时。

这是在数组上使用高阶函数 transform + filter 的另一种方法:

import pyspark.sql.functions as F

df1 = df.withColumn(
    "inferred_country",
    F.when(
        F.col("country").isNotNull(),
        F.array(F.col("country"))
    ).otherwise(
        F.expr("""filter(
                    transform(sources, (x, i) -> IF(boolean(infer_from_source[i]), x, null)),
                    x -> x is not null
                )""")
    )
)

df1.show()
#+-------+--------------------+--------------------+----------------+
#|country|             sources|   infer_from_source|inferred_country|
#+-------+--------------------+--------------------+----------------+
#|   null|[LUX, CZE, CHN, FRA]|[FALSE, TRUE, FAL...|      [CZE, FRA]|
#|    DEU|               [DEU]|             [FALSE]|           [DEU]|
#+-------+--------------------+--------------------+----------------+

从 Spark 3+ 开始,您可以在过滤器 lambda 函数中使用索引:

df1 = df.withColumn(
    "inferred_country",
    F.when(
        F.col("country").isNotNull(),
        F.array(F.col("country"))
    ).otherwise(
        F.expr("filter(sources, (x, i) -> boolean(infer_from_source[i]))")
    )
)