如何将 DataSetIterator 拆分为测试和训练迭代器?

How to split a DataSetIterator into testing and training iterator?

我正在使用 Deeplearning4j 和 datavec,我有一个 DataSetIterator 对象代表我所有的数据,它是一个时间序列。我如何将其拆分为训练和测试迭代器?我检查了一下,DataSetIterator Class 的方法已被弃用。谢谢。

遍历您的 DataSetIterator 并为每个 DataSet 条目创建两个新的 DataSets,每个用于训练和测试。

关键是使用 splitTestAndTrain 方法,它接受一个 double fractionTrain 来指定要训练的数据量(其余要测试)。该方法有不同的重载,因此您可以选择最适合您需要的一种。如果您希望将所有训练和测试数据集添加到一个公共迭代器中,您可以将它们存储在两个不同的列表中,稍后再获取它们对应的迭代器。类似于:

List<DataSet> trainList = new ArrayList<>();
List<DataSet> testList= new ArrayList<>();

while (yourDataSetIterator.hasNext())
{
    DataSet ds = yourDataSetIterator.next();
    SplitTestAndTrain splData = ds.splitTestAndTrain(0.5); //half for each         
    DataSet trainDs = splData.getTrain();
    trainList.add(trainDs);
    DataSet testDs  = splData.getTest();
    testList.add(testDs);
    (...)
}

Iterator<DataSet> trainIterator = trainList.iterator(); 
Iterator<DataSet> testIterator  = testList.iterator(); 

因为我不太了解这个库的具体细节,所以这个例子只是创建了“基本”iterators。这可能是自定义的,因此您创建 DataSetIterators

请注意,您可能还需要在拆分数据集之前对其进行洗牌 (ds.shuffle())。你可以找到一些例子 here


如果你想以特定的方式拆分它,你可以标记不同的条目并找到测试数据集的最大索引;然后,调用 splitTestAndTrain(int max) method, which specifically splits the dataset regarding the max parameter. The sortByLabel 方法在这里也有帮助。


Adam Gibson gave a great suggestion on the comments regarding other mechanism in order to split the DataSetIterator, which also seem to be a "more natural" way to do it, the DataSetIteratorSplitter.

它提供 getTrainIterator()getTestIterator() 方法,return 库的特定迭代器,DataSetIterator