How can I configure the number of classification classes/labels for StreamingLogisticRegressionWithSGD?
I am new to Spark MLlib. I am trying to implement a StreamingLogisticRegressionWithSGD model, and the Spark documentation offers very little information about it. When I enter 2,22-22-22 on the training socket stream, I get

ERROR DataValidators: Classification labels should be 0 or 1. Found 1 invalid labels

I understand it expects me to feed it features labeled 0 or 1, but I would really like to know whether I can configure it for more labels. I cannot figure out how to set the number of classes for StreamingLogisticRegressionWithSGD.

Thanks!
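For reference, the parsing code below reads each training line as label,f1-f2-f3, so under the default binary validator only lines whose label is 0 or 1 train without error, for example (feature values made up):

0,22-22-22
1,10-5-1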
Code
package test;

import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.StreamingContext;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

public class SLRPOC {

    private static StreamingLogisticRegressionWithSGD slrModel;
    private static int numFeatures = 3;

    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setMaster("local[3]").setAppName("SLRPOC");
        SparkContext sc = new SparkContext(sparkConf);
        StreamingContext ssc = new StreamingContext(sc, Durations.seconds(10));
        JavaStreamingContext jssc = new JavaStreamingContext(ssc);

        slrModel = new StreamingLogisticRegressionWithSGD()
                .setStepSize(0.5)
                .setNumIterations(10)
                .setInitialWeights(Vectors.zeros(numFeatures));

        // Train on the labeled stream (port 9998) and predict on the feature stream (port 9999).
        slrModel.trainOn(getDStreamTraining(jssc));
        slrModel.predictOn(getDStreamPrediction(jssc)).foreachRDD(new Function<JavaRDD<Double>, Void>() {

            private static final long serialVersionUID = 5287086933555760190L;

            @Override
            public Void call(JavaRDD<Double> v1) throws Exception {
                List<Double> list = v1.collect();
                for (Double d : list) {
                    System.out.println(d);
                }
                return null;
            }
        });

        jssc.start();
        jssc.awaitTermination();
    }

    // Parses training lines of the form "label,f1-f2-f3" into LabeledPoints.
    public static JavaDStream<LabeledPoint> getDStreamTraining(JavaStreamingContext context) {
        JavaReceiverInputDStream<String> lines = context.socketTextStream("localhost", 9998);
        return lines.map(new Function<String, LabeledPoint>() {

            private static final long serialVersionUID = 1268686043314386060L;

            @Override
            public LabeledPoint call(String data) throws Exception {
                System.out.println("Inside LabeledPoint call : ----- ");
                String arr[] = data.split(",");
                double vc[] = new double[3];
                String vcS[] = arr[1].split("-");
                int i = 0;
                for (String vcSi : vcS) {
                    vc[i++] = Double.parseDouble(vcSi);
                }
                return new LabeledPoint(Double.parseDouble(arr[0]), Vectors.dense(vc));
            }
        });
    }

    // Parses prediction lines of the form "f1-f2-f3" into dense vectors.
    public static JavaDStream<Vector> getDStreamPrediction(JavaStreamingContext context) {
        JavaReceiverInputDStream<String> lines = context.socketTextStream("localhost", 9999);
        return lines.map(new Function<String, Vector>() {

            private static final long serialVersionUID = 1268686043314386060L;

            @Override
            public Vector call(String data) throws Exception {
                System.out.println("Inside Vector call : ----- ");
                String vcS[] = data.split("-");
                double vc[] = new double[3];
                int i = 0;
                for (String vcSi : vcS) {
                    vc[i++] = Double.parseDouble(vcSi);
                }
                return Vectors.dense(vc);
            }
        });
    }
}
Exception
Inside LabeledPoint call : -----
16/05/18 17:51:10 INFO Executor: Finished task 0.0 in stage 4.0 (TID 4). 953 bytes result sent to driver
16/05/18 17:51:10 INFO TaskSetManager: Finished task 0.0 in stage 4.0 (TID 4) in 8 ms on localhost (1/1)
16/05/18 17:51:10 INFO TaskSchedulerImpl: Removed TaskSet 4.0, whose tasks have all completed, from pool
16/05/18 17:51:10 INFO DAGScheduler: ResultStage 4 (trainOn at SLRPOC.java:33) finished in 0.009 s
16/05/18 17:51:10 INFO DAGScheduler: Job 6 finished: trainOn at SLRPOC.java:33, took 0.019578 s
16/05/18 17:51:10 ERROR DataValidators: Classification labels should be 0 or 1. Found 1 invalid labels
16/05/18 17:51:10 INFO JobScheduler: Starting job streaming job 1463574070000 ms.1 from job set of time 1463574070000 ms
16/05/18 17:51:10 ERROR JobScheduler: Error running job streaming job 1463574070000 ms.0
org.apache.spark.SparkException: Input validation failed.
    at org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.run(GeneralizedLinearAlgorithm.scala:251)
    at org.apache.spark.mllib.regression.StreamingLinearAlgorithm$$anonfun$trainOn.apply(StreamingLinearAlgorithm.scala:94)
    at org.apache.spark.mllib.regression.StreamingLinearAlgorithm$$anonfun$trainOn.apply(StreamingLinearAlgorithm.scala:92)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply$mcV$sp(ForEachDStream.scala:42)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.DStream.createRDDWithLocalProperties(DStream.scala:399)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply$mcV$sp(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply(ForEachDStream.scala:40)
    at scala.util.Try$.apply(Try.scala:161)
    at org.apache.spark.streaming.scheduler.Job.run(Job.scala:34)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply$mcV$sp(JobScheduler.scala:207)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply(JobScheduler.scala:207)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply(JobScheduler.scala:207)
    at scala.util.DynamicVariable.withValue(DynamicVariable.scala:57)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler.run(JobScheduler.scala:206)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
    at java.lang.Thread.run(Thread.java:745)
Exception in thread "main" org.apache.spark.SparkException: Input validation failed.
    at org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.run(GeneralizedLinearAlgorithm.scala:251)
    at org.apache.spark.mllib.regression.StreamingLinearAlgorithm$$anonfun$trainOn.apply(StreamingLinearAlgorithm.scala:94)
    at org.apache.spark.mllib.regression.StreamingLinearAlgorithm$$anonfun$trainOn.apply(StreamingLinearAlgorithm.scala:92)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply$mcV$sp(ForEachDStream.scala:42)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$$anonfun$apply$mcV$sp.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.DStream.createRDDWithLocalProperties(DStream.scala:399)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply$mcV$sp(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun.apply(ForEachDStream.scala:40)
    at scala.util.Try$.apply(Try.scala:161)
    at org.apache.spark.streaming.scheduler.Job.run(Job.scala:34)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply$mcV$sp(JobScheduler.scala:207)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply(JobScheduler.scala:207)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run.apply(JobScheduler.scala:207)
    at scala.util.DynamicVariable.withValue(DynamicVariable.scala:57)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler.run(JobScheduler.scala:206)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
    at java.lang.Thread.run(Thread.java:745)
16/05/18 17:51:10 INFO StreamingContext: Invoking stop(stopGracefully=false) from shutdown hook
16/05/18 17:51:10 INFO SparkContext: Starting job: foreachRDD at SLRPOC.java:34
16/05/18 17:51:10 INFO DAGScheduler: Job 7 finished: foreachRDD at SLRPOC.java:34, took 0.000020 s
16/05/18 17:51:10 INFO JobScheduler: Finished job streaming job 1463574070000 ms.1 from job set of time 1463574070000 ms
16/05/18 17:51:10 INFO ReceiverTracker: Sent stop signal to all 2 receivers
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Received stop signal
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Stopping receiver with message: Stopped by driver:
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Called receiver onStop
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Deregistering receiver 1
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Received stop signal
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Stopping receiver with message: Stopped by driver:
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Called receiver onStop
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Deregistering receiver 0
16/05/18 17:51:10 ERROR ReceiverTracker: Deregistered receiver for stream 1: Stopped by driver
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Stopped receiver 1
16/05/18 17:51:10 ERROR ReceiverTracker: Deregistered receiver for stream 0: Stopped by driver
Not sure if you already solved this, but the algorithm you are using is binary: it only allows 2 classes, 0 or 1. If you want more, you need to use a multiclass classification algorithm:
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
new LogisticRegressionWithLBFGS().setNumClasses(10)
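Translated to the Java API used in the question, a minimal batch-mode sketch might look like the following. The training RDD, the class count of 10, and the sample prediction vector are placeholder assumptions; note that LogisticRegressionWithLBFGS trains on an RDD, so there is no drop-in streaming equivalent that accepts setNumClasses.

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

public class MulticlassLRSketch {

    // 'training' is assumed to be a JavaRDD<LabeledPoint> whose labels are 0..9,
    // e.g. parsed with the same "label,f1-f2-f3" scheme as in the question.
    public static LogisticRegressionModel train(JavaRDD<LabeledPoint> training) {
        return new LogisticRegressionWithLBFGS()
                .setNumClasses(10)     // allow labels 0..9 instead of just {0, 1}
                .run(training.rdd());  // LBFGS trains on a batch RDD, not a DStream
    }

    public static void main(String[] args) {
        // Build a SparkContext, load/parse a JavaRDD<LabeledPoint>, then:
        // LogisticRegressionModel model = train(training);
        // double predicted = model.predict(Vectors.dense(22.0, 22.0, 22.0));
    }
}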