Pyspark dataframe isin function 数据类型转换

Pyspark dataframe isin function datatype conversion

我正在使用 isin 函数来过滤 pyspark 数据帧。令人惊讶的是,尽管列数据类型 (double) 与列表中的数据类型 (Decimal) 不匹配,但存在匹配。谁能帮我理解为什么会这样?

例子

(Pdb) df.show(3)
+--------------------+---------+------------+
|           employee_id|threshold|wage|
+--------------------+---------+------------+
|AAA |      0.9|         0.5|      
|BBB |      0.8|         0.5|   
|CCC |      0.9|         0.5| 
+--------------------+---------+------------+

(Pdb) df.printSchema()
root
 |-- employee_id: string (nullable = true)
 |-- threshold: double (nullable = true)
 |-- wage: double (nullable = true)

(Pdb) include_thresholds
[Decimal('0.8')]

(Pdb) df.count()
3267                                                                           
(Pdb) df.filter(fn.col("threshold").isin(include_thresholds)).count()
1633

但是,如果我使用普通的“in”运算符来测试0.8是否属于include_thresholds,那显然是错误的

(Pdb) 0.8 in include_thresholds
False

函数 col 或 isin 是否隐式执行数据类型转换?

在 isin 文档中找到了答案:

https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/Column.html#isin-java.lang.Object...-

是辛 public 列在(对象...列表)中 一个布尔表达式,如果此表达式的值包含在参数的评估值中,则评估为 true。 注意:由于列表中元素的类型仅在 运行 时间内被推断,因此元素将被“向上转换”为最常见的类型以进行比较。例如:1)在“Int vs String”的情况下,“Int”将被向上转换为“String”,比较看起来像“String vs String”。 2)在“Float vs Double”的情况下,“Float”将被向上转换为“Double”,比较看起来像“Double vs Double”

当你将外部输入带到spark进行比较时。它们只是被当作字符串并根据上下文向上转换。

因此,您基于 numpy 数据类型观察到的内容可能不适用于 spark。

import decimal
include_thresholds=[decimal.Decimal(0.8)]
include_thresholds2=[decimal.Decimal('0.8')]

0.8 in include_thresholds  # True
0.8 in include_thresholds2  # False

然后,记下这些值

include_thresholds

[Decimal('0.8000000000000000444089209850062616169452667236328125')]

include_thresholds2

[Decimal('0.8')]

来到数据框

df = spark.sql(""" with t1 (
 select  'AAA'  c1, 0.9 c2,   0.5 c3    union all
 select  'BBB'  c1, 0.8 c2,   0.5 c3    union all
 select  'CCC'  c1, 0.9 c2,   0.5 c3
  )  select   c1 employee_id,   cast(c2 as double)  threshold,   cast(c3 as double) wage    from t1
""")

df.show()
df.printSchema()

+-----------+---------+----+
|employee_id|threshold|wage|
+-----------+---------+----+
|        AAA|      0.9| 0.5|
|        BBB|      0.8| 0.5|
|        CCC|      0.9| 0.5|
+-----------+---------+----+

root
 |-- employee_id: string (nullable = false)
 |-- threshold: double (nullable = false)
 |-- wage: double (nullable = false)

include_thresholds2 可以正常工作。

df.filter(col("threshold").isin(include_thresholds2)).show()

+-----------+---------+----+
|employee_id|threshold|wage|
+-----------+---------+----+
|        BBB|      0.8| 0.5|
+-----------+---------+----+

现在下面抛出错误。

df.filter(col("threshold").isin(include_thresholds)).show()

org.apache.spark.sql.AnalysisException: decimal can only support precision up to 38;

因为它采用值 0.8000000000000000444089209850062616169452667236328125 并尝试向上转换并因此抛出错误。