How can I configure the number of classification classes/labels for StreamingLogisticRegressionWithSGD?

I am new to Spark MLlib and am trying to implement a StreamingLogisticRegressionWithSGD model. The Spark documentation provides very little information about it. When I enter 2,22-22-22 on the socket stream, I get

ERROR DataValidators: Classification labels should be 0 or 1. Found 1 invalid labels

I understand that it expects the label to be 0 or 1, but I would really like to know whether I can configure it for more labels. I cannot find how to set the number of classes for StreamingLogisticRegressionWithSGD.
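For reference, this is a minimal, Spark-free sketch of the check that MLlib's DataValidators applies to each training line of the form label,f1-f2-f3 (the class and method names here are illustrative, not MLlib's):

```java
// Illustrative reimplementation of the binary-label check that rejects
// a line like "2,22-22-22": the label (the part before the comma) must
// parse to exactly 0.0 or 1.0 for binary logistic regression.
public class LabelCheck {

    static double parseLabel(String line) {
        // The label is the first comma-separated field.
        return Double.parseDouble(line.split(",")[0]);
    }

    static boolean isValidBinaryLabel(String line) {
        double label = parseLabel(line);
        return label == 0.0 || label == 1.0;
    }

    public static void main(String[] args) {
        System.out.println(isValidBinaryLabel("2,22-22-22")); // label 2 is rejected
        System.out.println(isValidBinaryLabel("1,22-22-22")); // label 1 is accepted
    }
}
```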

Thanks!

Code

package test;

import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.StreamingContext;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

public class SLRPOC {

    private static StreamingLogisticRegressionWithSGD slrModel;

    private static int numFeatures = 3;

    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setMaster("local[3]").setAppName("SLRPOC");
        SparkContext sc = new SparkContext(sparkConf);
        StreamingContext ssc = new StreamingContext(sc, Durations.seconds(10));
        JavaStreamingContext jssc = new JavaStreamingContext(ssc);

        slrModel = new StreamingLogisticRegressionWithSGD()
                .setStepSize(0.5)
                .setNumIterations(10)
                .setInitialWeights(Vectors.zeros(numFeatures));

        slrModel.trainOn(getDStreamTraining(jssc));
        slrModel.predictOn(getDStreamPrediction(jssc)).foreachRDD(new Function<JavaRDD<Double>, Void>() {

            private static final long serialVersionUID = 5287086933555760190L;

            @Override
            public Void call(JavaRDD<Double> v1) throws Exception {
                List<Double> list = v1.collect();
                for (Double d : list) {
                    System.out.println(d);
                }
                return null;
            }
        });

        jssc.start();
        jssc.awaitTermination();
    }

    public static JavaDStream<LabeledPoint> getDStreamTraining(JavaStreamingContext context) {
        JavaReceiverInputDStream<String> lines = context.socketTextStream("localhost", 9998);

        return lines.map(new Function<String, LabeledPoint>() {

            private static final long serialVersionUID = 1268686043314386060L;

            @Override
            public LabeledPoint call(String data) throws Exception {
                System.out.println("Inside LabeledPoint call : ----- ");
                String arr[] = data.split(",");
                double vc[] = new double[numFeatures];
                String vcS[] = arr[1].split("-");
                int i = 0;
                for (String vcSi : vcS) {
                    vc[i++] = Double.parseDouble(vcSi);
                }
                return new LabeledPoint(Double.parseDouble(arr[0]), Vectors.dense(vc));
            }
        });
    }

    public static JavaDStream<Vector> getDStreamPrediction(JavaStreamingContext context) {
        JavaReceiverInputDStream<String> lines = context.socketTextStream("localhost", 9999);

        return lines.map(new Function<String, Vector>() {

            private static final long serialVersionUID = 1268686043314386060L;

            @Override
            public Vector call(String data) throws Exception {
                System.out.println("Inside Vector call : ----- ");
                String vcS[] = data.split("-");
                double vc[] = new double[numFeatures];
                int i = 0;
                for (String vcSi : vcS) {
                    vc[i++] = Double.parseDouble(vcSi);
                }
                return Vectors.dense(vc);
            }
        });
    }
}
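As an aside, both map functions above parse the dash-separated feature string with the same loop. A small helper (illustrative only, kept free of Spark so it stands alone) could factor that out:

```java
// Illustrative helper, not part of the original class: parses a
// dash-separated feature string such as "22-22-22" into a double array,
// the shared logic of both DStream map functions above.
public class FeatureParser {

    static double[] parseFeatures(String data, int numFeatures) {
        String[] parts = data.split("-");
        double[] vc = new double[numFeatures];
        for (int i = 0; i < parts.length && i < numFeatures; i++) {
            vc[i] = Double.parseDouble(parts[i]);
        }
        return vc;
    }

    public static void main(String[] args) {
        double[] v = parseFeatures("22-22-22", 3);
        System.out.println(java.util.Arrays.toString(v)); // [22.0, 22.0, 22.0]
    }
}
```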

Exception

Inside LabeledPoint call : -----
16/05/18 17:51:10 INFO Executor: Finished task 0.0 in stage 4.0 (TID 4). 953 bytes result sent to driver
16/05/18 17:51:10 INFO TaskSetManager: Finished task 0.0 in stage 4.0 (TID 4) in 8 ms on localhost (1/1)
16/05/18 17:51:10 INFO TaskSchedulerImpl: Removed TaskSet 4.0, whose tasks have all completed, from pool
16/05/18 17:51:10 INFO DAGScheduler: ResultStage 4 (trainOn at SLRPOC.java:33) finished in 0.009 s
16/05/18 17:51:10 INFO DAGScheduler: Job 6 finished: trainOn at SLRPOC.java:33, took 0.019578 s
16/05/18 17:51:10 ERROR DataValidators: Classification labels should be 0 or 1. Found 1 invalid labels
16/05/18 17:51:10 INFO JobScheduler: Starting job streaming job 1463574070000 ms.1 from job set of time 1463574070000 ms
16/05/18 17:51:10 ERROR JobScheduler: Error running job streaming job 1463574070000 ms.0
org.apache.spark.SparkException: Input validation failed.
    at org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.run(GeneralizedLinearAlgorithm.scala:251)
    at org.apache.spark.mllib.regression.StreamingLinearAlgorithm$$anonfun$trainOn.apply(StreamingLinearAlgorithm.scala:94)
    at org.apache.spark.mllib.regression.StreamingLinearAlgorithm$$anonfun$trainOn.apply(StreamingLinearAlgorithm.scala:92)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply$mcV$sp(ForEachDStream.scala:42)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.DStream.createRDDWithLocalProperties(DStream.scala:399)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply$mcV$sp(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply(ForEachDStream.scala:40)
    at scala.util.Try$.apply(Try.scala:161)
    at org.apache.spark.streaming.scheduler.Job.run(Job.scala:34)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply$mcV$sp(JobScheduler.scala:207)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply(JobScheduler.scala:207)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply(JobScheduler.scala:207)
    at scala.util.DynamicVariable.withValue(DynamicVariable.scala:57)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler.run(JobScheduler.scala:206)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
    at java.lang.Thread.run(Thread.java:745)
Exception in thread "main" org.apache.spark.SparkException: Input validation failed.
    at org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.run(GeneralizedLinearAlgorithm.scala:251)
    at org.apache.spark.mllib.regression.StreamingLinearAlgorithm$$anonfun$trainOn.apply(StreamingLinearAlgorithm.scala:94)
    at org.apache.spark.mllib.regression.StreamingLinearAlgorithm$$anonfun$trainOn.apply(StreamingLinearAlgorithm.scala:92)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply$mcV$sp(ForEachDStream.scala:42)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.DStream.createRDDWithLocalProperties(DStream.scala:399)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply$mcV$sp(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply(ForEachDStream.scala:40)
    at scala.util.Try$.apply(Try.scala:161)
    at org.apache.spark.streaming.scheduler.Job.run(Job.scala:34)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply$mcV$sp(JobScheduler.scala:207)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply(JobScheduler.scala:207)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply(JobScheduler.scala:207)
    at scala.util.DynamicVariable.withValue(DynamicVariable.scala:57)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler.run(JobScheduler.scala:206)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
    at java.lang.Thread.run(Thread.java:745)
16/05/18 17:51:10 INFO StreamingContext: Invoking stop(stopGracefully=false) from shutdown hook
16/05/18 17:51:10 INFO SparkContext: Starting job: foreachRDD at SLRPOC.java:34
16/05/18 17:51:10 INFO DAGScheduler: Job 7 finished: foreachRDD at SLRPOC.java:34, took 0.000020 s
16/05/18 17:51:10 INFO JobScheduler: Finished job streaming job 1463574070000 ms.1 from job set of time 1463574070000 ms
16/05/18 17:51:10 INFO ReceiverTracker: Sent stop signal to all 2 receivers
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Received stop signal
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Stopping receiver with message: Stopped by driver:
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Called receiver onStop
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Deregistering receiver 1
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Received stop signal
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Stopping receiver with message: Stopped by driver:
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Called receiver onStop
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Deregistering receiver 0
16/05/18 17:51:10 ERROR ReceiverTracker: Deregistered receiver for stream 1: Stopped by driver
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Stopped receiver 1
16/05/18 17:51:10 ERROR ReceiverTracker: Deregistered receiver for stream 0: Stopped by driver

Not sure if you have solved this already, but the algorithm you are using is binary and only allows 2 classes, 0 or 1. If you want more, you need to use a multiclass algorithm:

import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel}
import org.apache.spark.mllib.evaluation.MulticlassMetrics

// trainingData: RDD[LabeledPoint] with labels 0.0 through 9.0
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(10)
  .run(trainingData)