Merge multiple columns in a Spark DataFrame [Java]

How can I merge multiple columns of a DataFrame (say, 3 columns) into a single column (in a new DataFrame), where each row becomes a Spark DenseVector? Similar to this question, but in Java and with the adjustments mentioned below.

I have tried using a UDF like this:

private UDF3<Double, Double, Double, Row> toColumn = new UDF3<Double, Double, Double, Row>() {

    private static final long serialVersionUID = 1L;

    public Row call(Double first, Double second, Double third) throws Exception {
        // wraps the dense vector in a Row, which (as shown below) nests it in a struct
        Row row = RowFactory.create(Vectors.dense(first, second, third));

        return row;
    }
};

Then I registered the UDF:

sqlContext.udf().register("toColumn", toColumn, dataType);

where dataType is:

StructType dataType = DataTypes.createStructType(new StructField[]{
    new StructField("bla", new VectorUDT(), false, Metadata.empty())
});
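
For reference, calling the UDF and inspecting the result looks roughly like this (a minimal sketch; the temp table name "input" and the feature column names are assumptions, matching the test further down):

// hypothetical invocation; "input" and the feature columns are assumed names
DataFrame result = sqlContext.sql(
    "SELECT toColumn(feature1, feature2, feature3) AS features FROM input");
result.printSchema();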

When I call this UDF on a DataFrame with 3 columns and print out the schema of the new DataFrame, I get:

root
 |-- features: struct (nullable = true)
 |    |-- bla: vector (nullable = false)

The problem here is that I need the vector to be at the top level, not inside a struct, like this:

root
 |-- features: vector (nullable = true)

I don't know how to get this, because the register function requires the return type of the UDF to be a DataType (which, in turn, does not offer a vector type).

You actually nested the vector type into a struct manually by using this data type:

new StructField("bla", new VectorUDT(), false, Metadata.empty()),

If you remove the outer StructField, you will get what you want. Of course, in this case you need to modify the signature of your function definition a bit; that is, you need to use the return type Vector.
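
Concretely, the UDF is then declared with Vector as its return type and registered with VectorUDT directly (both shown in full in the test below):

// the UDF is now a UDF3<Double, Double, Double, Vector>
sqlContext.udf().register("toColumn", toColumn, new VectorUDT());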

Please see my concrete example below, in the form of a simple JUnit test.

package sample.spark.test;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.junit.Test;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class ToVectorTest implements Serializable {
  private static final long serialVersionUID = 2L;

  private UDF3<Double, Double, Double, Vector> toColumn = new UDF3<Double, Double, Double, Vector>() {

    private static final long serialVersionUID = 1L;

    public Vector call(Double first, Double second, Double third) throws Exception {
      return Vectors.dense(first, second, third);
    }
  };

  @Test
  public void testUDF() {
    // context
    final JavaSparkContext sc = new JavaSparkContext("local", "ToVectorTest");
    final SQLContext sqlContext = new SQLContext(sc);

    // test input
    final DataFrame input = sqlContext.createDataFrame(
        sc.parallelize(
            Arrays.asList(
                RowFactory.create(1.0, 2.0, 3.0),
                RowFactory.create(4.0, 5.0, 6.0),
                RowFactory.create(7.0, 8.0, 9.0),
                RowFactory.create(10.0, 11.0, 12.0)
            )),
        DataTypes.createStructType(
            Arrays.asList(
                new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("feature3", DataTypes.DoubleType, false, Metadata.empty())
            )
        )
    );
    input.registerTempTable("input");

    // expected output
    final Set<Vector> expectedOutput = new HashSet<>(Arrays.asList(
        Vectors.dense(1.0, 2.0, 3.0),
        Vectors.dense(4.0, 5.0, 6.0),
        Vectors.dense(7.0, 8.0, 9.0),
        Vectors.dense(10.0, 11.0, 12.0)
    ));

    // processing
    sqlContext.udf().register("toColumn", toColumn, new VectorUDT());
    final DataFrame outputDF = sqlContext.sql("SELECT toColumn(feature1, feature2, feature3) AS x FROM input");
    final Set<Vector> output = new HashSet<>(outputDF.toJavaRDD().map(r -> r.<Vector>getAs("x")).collect());

    // evaluation
    assertEquals(expectedOutput.size(), output.size());
    for (Vector x : output) {
      assertTrue(expectedOutput.contains(x));
    }

    // show the schema and the content
    System.out.println(outputDF.schema());
    outputDF.show();

    sc.stop();
  }
}
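
As a side note: if a handwritten UDF isn't a requirement, Spark's ML pipeline API ships a VectorAssembler transformer that performs the same merge. A minimal sketch, assuming the same input DataFrame as in the test above:

import org.apache.spark.ml.feature.VectorAssembler;

// merge the three double columns into a single vector column named "features"
VectorAssembler assembler = new VectorAssembler()
    .setInputCols(new String[]{"feature1", "feature2", "feature3"})
    .setOutputCol("features");
DataFrame assembled = assembler.transform(input);
assembled.printSchema(); // ... |-- features: vector (nullable = true)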