如何将具有 Decimal 的 spark DataFrame 转换为具有相同精度的 BigDecimal 的 Dataset?

How to convert a spark DataFrame with a Decimal to a Dataset with a BigDecimal of the same precision?

如何创建具有给定精度的 BigDecimal 的 spark 数据集?请参阅 spark shell 中的以下示例。您会看到我可以创建具有所需 BigDecimal 精度的 DataFrame,但无法将其转换为数据集。

scala> import scala.collection.JavaConverters._
scala> case class BD(dec: BigDecimal)
scala> val schema = StructType(Seq(StructField("dec", DecimalType(38, 0))))
scala> val highPrecisionDf = spark.createDataFrame(List(Seq(BigDecimal("12345678901122334455667788990011122233"))).map(a => Row.fromSeq(a)).asJava, schema)
highPrecisionDf: org.apache.spark.sql.DataFrame = [dec: decimal(38,0)]
scala> highPrecisionDf.as[BD]
org.apache.spark.sql.AnalysisException: Cannot up cast `dec` from decimal(38,0) to decimal(38,18) as it may truncate
The type path of the target object is:
- field (class: "scala.math.BigDecimal", name: "dec")
- root class: "BD"
You can either add an explicit cast to the input data or choose a higher precision type of the field in the target object;

同样,我无法从 class 使用更高精度 BigDecimal 的情况创建数据集。

scala> List(BD(BigDecimal("12345678901122334455667788990011122233"))).toDS.show()
+----+
| dec|
+----+
|null|
+----+

是否有任何方法可以创建包含精度不同于默认小数 (38,18) 的 BigDecimal 字段的数据集?

我发现的一种解决方法是在数据集中使用字符串来保持精度。如果您不需要将值用作数字(例如排序或数学),则此解决方案有效。如果您需要这样做,您可以将其转换回 DataFrame,转换为适当的高精度类型,然后再转换回您的数据集。

val highPrecisionDf = spark.createDataFrame(List(Seq(BigDecimal("12345678901122334455667788990011122233"))).map(a => Row.fromSeq(a)).asJava, schema)
case class StringDecimal(dec: String)
highPrecisionDf.as[StringDecimal]

默认情况下,在 class 的情况下,spark 将推断 Decimal 类型(或 BigDecimal)的模式为 DecimalType(38, 18)(参见 org.apache.spark.sql.types.DecimalType.SYSTEM_DEFAULT

解决方法是将数据集转换为数据框,如下所示

case class TestClass(id: String, money: BigDecimal)

val testDs = spark.createDataset(Seq(
  TestClass("1", BigDecimal("22.50")),
  TestClass("2", BigDecimal("500.66"))
))

testDs.printSchema()

root
 |-- id: string (nullable = true)
 |-- money: decimal(38,18) (nullable = true)

解决方法

import org.apache.spark.sql.types.DecimalType
val testDf = testDs.toDF()

testDf
  .withColumn("money", testDf("money").cast(DecimalType(10,2)))
  .printSchema()

root
 |-- id: string (nullable = true)
 |-- money: decimal(10,2) (nullable = true)

您可以查看此 link 以获得更详细的信息 https://issues.apache.org/jira/browse/SPARK-18484)