DeepLearning4j 和 DataVec 读取带标签的 csv 文件
DeepLearning4j and DataVec read csv file with label
我已经构建了一个 DL4j 项目。如果我按如下方式使用 MNIST 数据集,一切都很好:
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
但是,我想切换到我自己的 csv 文件,格式如下:
A | B | C | X | Y
-------------------------
1 | 100 | 5 | 15 | 6
...
X
和 Y
是结果(或标签)。因为我打算进行回归分析,所以X
和Y
都是实数。所以我使用以下代码读取了 csv 文件:
RecordReader recordReaderTrain = new CSVRecordReader(1, ",");
recordReaderTrain.initialize(new FileSplit(new File("src/main/resources/data/Data.csv")));
DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 2);
代码中的3
表示index of the labels
,2
表示number of possible labels
。这两个参数就不多解释了。我猜他们的意思是标签从第 4 列开始,有 2 个标签。
当我运行代码时,它显示以下异常:
Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 14
我认为是因为 dl4j 不识别 15
作为标签。
所以我的问题是:如何正确读取 csv 文件以进行回归分析?
非常感谢。
您需要将 regression true(它是构造函数的额外部分)传递给 RecordReaderDataSetIterator。
我已经构建了一个 DL4j 项目。如果我按如下方式使用 MNIST 数据集,一切都很好:
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
但是,我想切换到我自己的 csv 文件,格式如下:
A | B | C | X | Y
-------------------------
1 | 100 | 5 | 15 | 6
...
X
和 Y
是结果(或标签)。因为我打算进行回归分析,所以X
和Y
都是实数。所以我使用以下代码读取了 csv 文件:
RecordReader recordReaderTrain = new CSVRecordReader(1, ",");
recordReaderTrain.initialize(new FileSplit(new File("src/main/resources/data/Data.csv")));
DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 2);
代码中的3
表示index of the labels
,2
表示number of possible labels
。这两个参数就不多解释了。我猜他们的意思是标签从第 4 列开始,有 2 个标签。
当我运行代码时,它显示以下异常:
Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 14
我认为是因为 dl4j 不识别 15
作为标签。
所以我的问题是:如何正确读取 csv 文件以进行回归分析?
非常感谢。
您需要将 regression true(它是构造函数的额外部分)传递给 RecordReaderDataSetIterator。