ND4J DL4J 将数据导入拟合方法
ND4J DL4J Getting data into fit method
我有一个 INDArray 数据,它是 x 的 return 值,如下所示:
private static INDArray createDataSet(String path)throws Exception {
List<String> lines = IOUtils.readLines(new FileInputStream(path), StandardCharsets.UTF_8);
double[] position = new double[lines.size()];
double[] year = new double[lines.size()];
double[] month = new double[lines.size()];
double[] day = new double[lines.size()];
double[] close = new double[lines.size()];
int linecount = 0;
Iterator<String> it = lines.iterator();
while(it.hasNext()) {
String line = it.next();
String[] parts = line.split(",");
position[linecount] = linecount;
year[linecount] = Double.valueOf(parts[0]);
month[linecount] = Double.valueOf(parts[1]);
day[linecount] = Double.valueOf(parts[2]);
close[linecount] = Double.valueOf(parts[5]);
linecount++;
}//endloop
double[][] arr2D = new double[][] {position, year, month, day, close};
INDArray x = Nd4j.createFromArray(arr2D);
return x;
}
我正在尝试复制 csvplotter 示例并使用单个 in/out 网络执行线性回归。
如何将数组行 (0) 作为特征加载,将数组行 (4) 作为标签加载?
更多信息:
System.out.println(ds.rank());
long[] l = ds.shape();
System.out.println(l[0] + " , " + l[1] + " - " + l.length);
System.out.println(ds.length());
结果:
2,
5, 1260 -2
6300
这里是我的问题,只是为了清楚:
for (int i = 0; i < nEpochs; i++) {
net.fit(d);
}
导致各种错误,具体取决于我尝试添加数据的方式
虽然我没有得到有效的答案,但我意识到了我的问题。基于 csv 绘图仪示例中的评论,我假设 indarray 的行被传递到输入。然而,实际传递给输入的是列。
通过转置 INDArray 并添加两列,我需要网络处理数据。
INDArray ds;
ds = ds.transpose();
DataSet ddd = new DataSet();
ddd.setFeatures(ds.getColumn(0, true)); //true maintains matrix instead of vector
ddd.setLabels(ds.getColumn(4, true));
ddd.dataSetBatches(500);
System.out.println(ddd);
我的打印输出:
===========INPUT===================
[[0],
[1.0000],
[2.0000],
...,
[1257.0000],
[1258.0000],
[1259.0000]]
=================OUTPUT==================
[[540.3100],
[536.7000],
[533.3300],
...,
[1431.8199],
[1439.2200],
[1436.3800]]
虽然训练不成功,但这确实回答了我原来的问题。
我有一个 INDArray 数据,它是 x 的 return 值,如下所示:
private static INDArray createDataSet(String path)throws Exception {
List<String> lines = IOUtils.readLines(new FileInputStream(path), StandardCharsets.UTF_8);
double[] position = new double[lines.size()];
double[] year = new double[lines.size()];
double[] month = new double[lines.size()];
double[] day = new double[lines.size()];
double[] close = new double[lines.size()];
int linecount = 0;
Iterator<String> it = lines.iterator();
while(it.hasNext()) {
String line = it.next();
String[] parts = line.split(",");
position[linecount] = linecount;
year[linecount] = Double.valueOf(parts[0]);
month[linecount] = Double.valueOf(parts[1]);
day[linecount] = Double.valueOf(parts[2]);
close[linecount] = Double.valueOf(parts[5]);
linecount++;
}//endloop
double[][] arr2D = new double[][] {position, year, month, day, close};
INDArray x = Nd4j.createFromArray(arr2D);
return x;
}
我正在尝试复制 csvplotter 示例并使用单个 in/out 网络执行线性回归。
如何将数组行 (0) 作为特征加载,将数组行 (4) 作为标签加载?
更多信息:
System.out.println(ds.rank());
long[] l = ds.shape();
System.out.println(l[0] + " , " + l[1] + " - " + l.length);
System.out.println(ds.length());
结果:
2,
5, 1260 -2
6300
这里是我的问题,只是为了清楚:
for (int i = 0; i < nEpochs; i++) {
net.fit(d);
}
导致各种错误,具体取决于我尝试添加数据的方式
虽然我没有得到有效的答案,但我意识到了我的问题。基于 csv 绘图仪示例中的评论,我假设 indarray 的行被传递到输入。然而,实际传递给输入的是列。
通过转置 INDArray 并添加两列,我需要网络处理数据。
INDArray ds;
ds = ds.transpose();
DataSet ddd = new DataSet();
ddd.setFeatures(ds.getColumn(0, true)); //true maintains matrix instead of vector
ddd.setLabels(ds.getColumn(4, true));
ddd.dataSetBatches(500);
System.out.println(ddd);
我的打印输出:
===========INPUT===================
[[0],
[1.0000],
[2.0000],
...,
[1257.0000],
[1258.0000],
[1259.0000]]
=================OUTPUT==================
[[540.3100],
[536.7000],
[533.3300],
...,
[1431.8199],
[1439.2200],
[1436.3800]]
虽然训练不成功,但这确实回答了我原来的问题。