Declare StructType of a Dataframe: column containing org.apache.spark.ml.linalg.Vector

I have a DataFrame named df1 with the following schema:

root
 |-- instances: string (nullable = true)
 |-- features: vector (nullable = true)
 |-- label: double (nullable = false)

where features and label were obtained from a LabeledPoint. I want to create a new DataFrame in which the contents of instances and features are modified. To do this, I wrote the following:

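import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StructType

val schema2 = new StructType()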
  .add("instances", "string")
  .add("features", "org.apache.spark.ml.linalg.Vector")  // also I've tried using `vector`
  .add("label", "double")

val schemaEncoder = RowEncoder(schema2)

val df2 = df1.map {
  case Row(inst: String, lp: org.apache.spark.ml.linalg.Vector, lbl: Double) => {
    val nInst = modifyInstances(inst)
    val nLP = nInst.split(",")
    Row(nInst, nLP, lbl)
  }
}(schemaEncoder)

When I run the code, it fails at .add("features", "org.apache.spark.ml.linalg.Vector").

Any suggestions?

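The add(name, dataType: String) overload parses its second argument as a SQL type string, and SQL has no type name for ML vectors, so neither "org.apache.spark.ml.linalg.Vector" nor "vector" is accepted. Instead, pass the DataType object org.apache.spark.ml.linalg.SQLDataTypes.VectorType, as shown below: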

import org.apache.spark.sql.types._
import org.apache.spark.ml.linalg.SQLDataTypes._

val schema2 = new StructType().
  add("instances", StringType).
  add("features", VectorType).
  add("label", DoubleType)
// schema2: org.apache.spark.sql.types.StructType = StructType(
//   StructField(instances,StringType,true),
//   StructField(features,org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7,true),
//   StructField(label,DoubleType,true)
// )
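The schema alone is not quite enough: in the question's map, nInst.split(",") produces an Array[String], which does not match VectorType, so the features value also has to be turned into a Vector. The following is a minimal sketch under stated assumptions, not the asker's actual code: modifyInstances here is a hypothetical stand-in assumed to return comma-separated numbers, the sample rows are made up, and the DataFrame is rebuilt with df1.rdd.map plus spark.createDataFrame(rdd, schema) rather than RowEncoder, which lives in Spark's internal catalyst package.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}

object VectorSchemaExample {
  // Hypothetical stand-in for the asker's modifyInstances: assumed to
  // return a comma-separated list of numbers (here it just doubles each value).
  def modifyInstances(inst: String): String =
    inst.split(",").map(s => (s.trim.toDouble * 2).toString).mkString(",")

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("vector-schema-example")
      .getOrCreate()

    // Same schema as in the answer: the vector column uses VectorType.
    val schema2 = new StructType()
      .add("instances", StringType)
      .add("features", VectorType)
      .add("label", DoubleType)

    // Made-up input rows, built directly against the schema above.
    val df1 = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq(
        Row("1.0,2.0", Vectors.dense(1.0, 2.0), 0.0),
        Row("3.0,4.0", Vectors.dense(3.0, 4.0), 1.0)
      )),
      schema2)

    // Map over the underlying RDD[Row]. The new features value is a dense
    // Vector parsed from the modified string, so it matches VectorType.
    val rows2 = df1.rdd.map {
      case Row(inst: String, _: Vector, lbl: Double) =>
        val nInst = modifyInstances(inst)
        val nLP = Vectors.dense(nInst.split(",").map(_.trim.toDouble))
        Row(nInst, nLP, lbl)
    }
    val df2 = spark.createDataFrame(rows2, schema2)

    df2.printSchema()
    df2.show(false)
    spark.stop()
  }
}
```

Running it needs spark-sql and spark-mllib on the classpath; df2.printSchema() should then report features as vector, as in the question's original schema.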