根据给定的输入数组过滤数据框中的数组列——Pyspark

Filter array column in a dataframe based on a given input array --Pyspark

我有一个这样的数据框

Studentname  Speciality
Alex         ["Physics","Math","biology"]
Sam          ["Economics","History","Math","Physics"]
Claire       ["Political science,Physics"]

我想找到所有在 [Physics,Math] 专业的学生,​​所以输出应该有 2 行 Alex,Sam

这是我试过的

from pyspark.sql.functions import array_contains
from pyspark.sql import functions as F

def student_info():
     student_df = spark.read.parquet("s3a://studentdata")
     a1=["Physics","Math"]
     df=student_df
     for a in a1:
       df= student_df.filter(array_contains(student_df.Speciality, a))
       print(df.count())

student_info()

output:
3
2

想知道如何根据给定的数组子集过滤数组列

假设您有,学生 Speciality 中没有重复项(例如

StudentName   Speciality
SomeStudent   ['Physics', 'Math', 'Biology', 'Physics']

您可以在 pandas

中将 explodegroupby 一起使用

所以,对于你的问题

# df is above dataframe
# Lookup subjects
a1 = ['Physics', 'Math']

gdata = df.explode('Speciality').groupby(['Speciality']).size().to_frame('Count')

gdata.loc[a1, 'Count']

#             Count
# Speciality
# Physics         3
# Math            2

使用高阶函数filter应该是最可扩展高效的方法(Spark2.4 )

from pyspark.sql import functions as F
df.withColumn("new", F.size(F.expr("""filter(Speciality, x-> x=='Math' or x== 'Physics')""")))\
  .filter("new=2").drop("new").show(truncate=False)
+-----------+-----------------------------------+
|Studentname|Speciality                         |
+-----------+-----------------------------------+
|Alex       |[Physics, Math, biology]           |
|Sam        |[Economics, History, Math, Physics]|
+-----------+-----------------------------------+

如果你想使用像 a1 这样的 array 动态地 这样做,你可以使用 F.array_exceptF.array 然后 filter on size ( spark 2.4 ):

a1=['Math','Physics']
df.withColumn("array", F.array_except("Speciality",F.array(*(F.lit(x) for x in a1))))\
  .filter("size(array)= size(Speciality)-2").drop("array").show(truncate=False)

+-----------+-----------------------------------+
|Studentname|Speciality                         |
+-----------+-----------------------------------+
|Alex       |[Physics, Math, biology]           |
|Sam        |[Economics, History, Math, Physics]|
+-----------+-----------------------------------+

要获得计数,您可以输入 .count() 而不是 .show()

这是另一种利用 array_sort 和 Spark 相等运算符的方法,它像处理任何其他类型一样处理数组,前提是它们已排序:

from pyspark.sql.functions import lit, array, array_sort, array_intersect

target_ar = ["Physics", "Math"]
search_ar = array_sort(array(*[lit(e) for e in target_ar]))

df.where(array_sort(array_intersect(df["Speciality"], search_ar)) == search_ar) \
  .show(10, False)

# +-----------+-----------------------------------+
# |Studentname|Speciality                         |
# +-----------+-----------------------------------+
# |Alex       |[Physics, Math, biology]           |
# |Sam        |[Economics, History, Math, Physics]|
# +-----------+-----------------------------------+

首先我们找到与array_intersect(df["Speciality"], search_ar)相同的元素,然后我们使用==比较排序后的数组。