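Deeplearning4j parsing model into a DataSet
The official Deeplearning4j guide shows how to work with .csv files, but I would like to know how to use my own custom model. I looked for a suitable DataSet implementation but could not find one. Even something that accepts the contents of a plain .csv as a String would be enough. This is what I tried:
Model: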
package com.example.kamil.deeplearningandroid;

public class Job implements LearnableModel {
    private int type;
    private int salary;
    private int choice;

    public Job(String type, int salary, boolean choice) {
        this.type = encodeType(type);
        this.salary = salary;
        this.choice = encodeChoice(choice);
    }

    private int encodeType(String job) {
        switch (job) {
            case "Mechanic": return 0;
            case "Programmer": return 1;
            case "Teacher": return 2;
            case "Driver": return 3;
            case "Cook": return 4;
            default: return 5;
        }
    }

    private int encodeChoice(boolean choice) {
        return choice ? 1 : 0;
    }

    @Override
    public String toString() {
        return type + SEPARATOR + salary + SEPARATOR + choice + "\n";
    }
}
And in JobClassifier:
private DataSet readStringDataset(List<LearnableModel> data, int batchSize, int labelIndex, int numClasses) throws IOException, InterruptedException {
    RecordReader rr = new LineRecordReader();
    rr.initialize(new StringSplit(modelToString(data)));
    DataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
    return iterator.next();
}

private String modelToString(List<LearnableModel> list) {
    StringBuilder sb = new StringBuilder();
    for (LearnableModel model : list) {
        sb.append(model.toString());
    }
    return sb.toString();
}
With all of that, I get:
W/System.err: java.lang.NumberFormatException: Invalid double: "1,10,0
W/System.err: 1,15,1
W/System.err: 4,7,0
W/System.err: 5,10,1
W/System.err: 3,10,0
W/System.err: 3,20,0
W/System.err: 4,5,0
W/System.err: 4,12,1
W/System.err: 2,20,1
W/System.err: 2,4,0
W/System.err: 5,12,1
W/System.err: 0,10,0
W/System.err: 5,5,0
W/System.err: 1,10,0
W/System.err: 2,16,1
W/System.err: 3,30,1
W/System.err: 4,16,1
W/System.err: 5,19,1
W/System.err: 5,6,0
W/System.err: 1,11,0"
W/System.err: at java.lang.StringToReal.invalidReal(StringToReal.java:63)
W/System.err: at java.lang.StringToReal.initialParse(StringToReal.java:164)
W/System.err: at java.lang.StringToReal.parseDouble(StringToReal.java:282)
W/System.err: at java.lang.Double.parseDouble(Double.java:301)
W/System.err: at org.datavec.api.writable.Text.toDouble(Text.java:601)
W/System.err: at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.getDataSet(RecordReaderDataSetIterator.java:271)
W/System.err: at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:177)
W/System.err: at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:372)
W/System.err: at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:52)
W/System.err: at com.example.kamil.deeplearningandroid.JobClassifier.readStringDataset(JobClassifier.java:185)
W/System.err: at com.example.kamil.deeplearningandroid.JobClassifier.classify(JobClassifier.java:65)
W/System.err: at com.example.kamil.deeplearningandroid.MainActivity.onCreate(MainActivity.java:23)
W/System.err: at android.app.Activity.performCreate(Activity.java:6251)
W/System.err: at android.app.Instrumentation.callActivityOnCreate(Instrumentation.java:1107)
W/System.err: at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:2369)
W/System.err: at android.app.ActivityThread.handleLaunchActivity(ActivityThread.java:2476)
W/System.err: at android.app.ActivityThread.-wrap11(ActivityThread.java)
W/System.err: at android.app.ActivityThread$H.handleMessage(ActivityThread.java:1344)
W/System.err: at android.os.Handler.dispatchMessage(Handler.java:102)
W/System.err: at android.os.Looper.loop(Looper.java:148)
W/System.err: at android.app.ActivityThread.main(ActivityThread.java:5417)
W/System.err: at java.lang.reflect.Method.invoke(Native Method)
W/System.err: at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:726)
W/System.err: at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:616)
You should use DataVec. There is no "DataSet implementation"; everything is converted to ndarrays.
Our examples cover this and more: http://github.com/deeplearning4j/dl4j-examples
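Edit: for inference on a simple CSV, you can do something as simple as:
String[] fields = line.split(",");
Parse each field with Float.parseFloat() or Double.parseDouble()
to build a float[] or double[].
Then do:
INDArray arr = Nd4j.create(values); where values is the float[] or double[] you just built.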
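The parsing step above can be sketched in plain Java. This is only an illustration: `CsvLineParser` and `parseLine` are hypothetical names, not part of Deeplearning4j or ND4J, and the sketch assumes every field in the line is numeric and comma-separated, as in the question's data.

```java
import java.util.Arrays;

// Hypothetical helper, not part of Deeplearning4j or ND4J.
final class CsvLineParser {

    // Splits one CSV record on commas and parses every field as a double.
    // Assumes all fields are numeric; a non-numeric field throws NumberFormatException.
    static double[] parseLine(String line) {
        String[] fields = line.trim().split(",");
        double[] values = new double[fields.length];
        for (int i = 0; i < fields.length; i++) {
            values[i] = Double.parseDouble(fields[i].trim());
        }
        return values;
    }

    public static void main(String[] args) {
        double[] row = parseLine("1,10,0");
        System.out.println(Arrays.toString(row)); // prints [1.0, 10.0, 0.0]
        // With ND4J on the classpath, the next step from the answer would be:
        // INDArray arr = Nd4j.create(row);
    }
}
```

Note that the line is parsed one record at a time. The exception in the question shows the whole multi-line string being handed to Double.parseDouble in one piece, which is what this per-line, per-field parsing avoids.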
You do not need a DataSet for inference/scoring; it is only needed for training. For training, you can use DataVec with RecordReaderDataSetIterator, or with SequenceRecordReaderDataSetIterator for time series.
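To make "everything is converted to ndarrays" more concrete for the training case, here is a rough sketch in plain Java of what the `labelIndex` and `numClasses` arguments from the question amount to: the column at `labelIndex` becomes a one-hot label row of width `numClasses`, and the remaining columns become the feature row. `CsvToArrays` is a hypothetical class written for this illustration and is not how RecordReaderDataSetIterator is implemented. It assumes integer class labels in the range 0 to numClasses - 1 and records separated by "\n", as in the question's toString().

```java
import java.util.Arrays;

// Hypothetical illustration; plain Java only, no DL4J or DataVec classes involved.
final class CsvToArrays {
    final double[][] features; // one row per record, label column removed
    final double[][] labels;   // one row per record, one-hot over numClasses

    CsvToArrays(String csv, int labelIndex, int numClasses) {
        String[] lines = csv.trim().split("\n");
        features = new double[lines.length][];
        labels = new double[lines.length][numClasses];
        for (int r = 0; r < lines.length; r++) {
            String[] fields = lines[r].trim().split(",");
            features[r] = new double[fields.length - 1];
            int f = 0;
            for (int c = 0; c < fields.length; c++) {
                if (c == labelIndex) {
                    // Assumes the label is an integer class index in [0, numClasses).
                    labels[r][Integer.parseInt(fields[c].trim())] = 1.0;
                } else {
                    features[r][f++] = Double.parseDouble(fields[c].trim());
                }
            }
        }
    }

    public static void main(String[] args) {
        CsvToArrays data = new CsvToArrays("1,10,0\n1,15,1\n4,7,0\n", 2, 2);
        System.out.println(Arrays.deepToString(data.features)); // prints [[1.0, 10.0], [1.0, 15.0], [4.0, 7.0]]
        System.out.println(Arrays.deepToString(data.labels));   // prints [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    }
}
```

With ND4J available, arrays like these could be wrapped with Nd4j.create(double[][]). For real training data the DataVec route recommended above is the better fit, since it handles batching for you.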