How to get StratifiedKFold in Scala Spark MLLib

I searched around and couldn't find one, so I wrote a StratifiedKFold method that works with Scala and Spark (MLlib). I'm posting my answer below.

import scala.util.Random

// Splits the sample indices into (at most) k chunks per label, so every fold preserves the label proportions.
// Returns a map label -> chunks of indices, plus the number of distinct labels.
def StratifiedKFold(nSamples: Int, k: Int, labels: List[Int],
                    shuffle: Boolean = false): (Map[Int, List[List[Int]]], Int) = {

    val idxs = (0 until nSamples).toList
    val unqLabels = labels.distinct
    val noOfLabels = unqLabels.length

    // group the sample indices by their label
    val idxsByLabel = idxs.groupBy(labels(_))
    var stratifiedIdxs: Map[Int, List[List[Int]]] = Map.empty

    for (label <- unqLabels) {
        val labelGroup = if (shuffle) Random.shuffle(idxsByLabel(label)) else idxsByLabel(label)
        // ceiling division, so grouped() yields at most k chunks for this label
        val chunkSize = if (labelGroup.length % k == 0) labelGroup.length / k else labelGroup.length / k + 1
        stratifiedIdxs += (label -> labelGroup.grouped(chunkSize).toList)
    }
    (stratifiedIdxs, noOfLabels)
}
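
For reference, here is a minimal usage sketch (the toy label vector, k, and the byLabel/folds/train/test names are only illustrative, not part of the method itself). It turns the returned map into k (train, test) index pairs by using the f-th chunk of every label as the test set of fold f:

val labels = List(0, 0, 0, 1, 1, 1, 1, 0, 1, 0)   // toy label vector
val k = 3
val (byLabel, _) = StratifiedKFold(labels.length, k, labels, shuffle = true)

val folds: Seq[(List[Int], List[Int])] = (0 until k).map { f =>
    // test indices for fold f: the f-th chunk of every label (if that label has one)
    val test  = byLabel.values.flatMap(chunks => if (f < chunks.length) chunks(f) else Nil).toList
    val train = labels.indices.toList.filterNot(test.contains)
    (train, test)
}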

Here is a Java version of the same stratified k-fold split for a JavaRDD:

/*
 * Stratified k-folds: each category is split fairly between training and testing samples.
 * For example, with 3 folds and a category '0' that has 3 samples, every output fold ends up with
 * 2 of them in training and 1 in testing. The split is as even as possible, e.g. 11 samples over
 * 3 folds become [4, 4, 3]. If a category has fewer samples than k, some folds will contain no
 * samples of that category; this breaks stratification and can hurt accuracy, but it is not fatal.
 */
private static Tuple2<JavaRDD<LabeledPoint>, JavaRDD<LabeledPoint>>[] stratifyKFolds(JavaRDD<LabeledPoint> data, Broadcast<Integer> kBC){
    JavaRDD<List<List<LabeledPoint>>> foldedLP =  data.mapToPair( //map to (key,list) pair.
            lp -> {
                List<LabeledPoint> list = new ArrayList<>();
                list.add(lp);
                return new Tuple2<>(lp.label(), list);
            }
    ).reduceByKey( //aggregate
            (List<LabeledPoint> list1, List<LabeledPoint> list2) -> {
                list1.addAll(list2);
                return list1; //put the LabeledPoint with the same key into one list.
            }
    ).values() //get list only after aggregate.
    .map( //split each list into K folds.
            list -> {
                //shuffle, then distribute the elements into k folds.
                Collections.shuffle(list);
                int total = list.size();
                if(total < kBC.value()){
                    log.warn("category size {} is less than folds number {}, this will break stratification and leads to bad folds splitting", total, kBC.value());
                }
                //assign each element to a fold.
                List<List<LabeledPoint>> keyFolds = new ArrayList<>();
                int avg = total / kBC.value(); //average number of elements per fold.
                int remain = total % kBC.value(); //leftover elements, handed out one each to the leftmost folds.
                for(int i=0, index = 0, count=0; i<kBC.value(); i++, index += count){
                    //number of elements that go into fold i.
                    count = (i<remain) ? avg+1 : avg;
                    keyFolds.add(list.subList(index, index+count));
                }
                return keyFolds;
            }
    );
    foldedLP.persist(StorageLevel.MEMORY_AND_DISK_SER());

    Tuple2<JavaRDD<LabeledPoint>, JavaRDD<LabeledPoint>>[] result = new Tuple2[kBC.value()];
    //initialize folds, each fold is a tuple of (training RDD, testing RDD).
    for(int i=0; i<kBC.value(); i++){
        final int ii = i; //must be (effectively) final to be used inside the lambda.
        JavaRDD<LabeledPoint> training_i = foldedLP.flatMap(
                list -> {
                    List<LabeledPoint> noII = new ArrayList<>();
                    for(int j=0; j<kBC.value(); j++){
                        if(j != ii){
                            noII.addAll(list.get(j)); //collect every fold except the ii-th.
                        }
                    }
                    return noII.iterator();
                }
        );
        JavaRDD<LabeledPoint> testing_i = foldedLP.flatMap(
                list -> list.get(ii).iterator() //the ii-th fold becomes the test set.
        );
        result[i] = new Tuple2<>(training_i, testing_i);
    }
    //note: training_i/testing_i are lazy, so unpersisting here may cause foldedLP to be recomputed when they are evaluated.
    foldedLP.unpersist();

    return result;
}
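
Since the question targets Scala, the same idea can also be sketched directly on an RDD[LabeledPoint]: group by label, shuffle within each label, and hand the points out round-robin over k folds. This is only a rough sketch of the approach used in the Java method above; the name stratifiedKFoldRDD and the inputs data/k are illustrative, not an existing API.

import scala.util.Random
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint

def stratifiedKFoldRDD(data: RDD[LabeledPoint], k: Int): Array[(RDD[LabeledPoint], RDD[LabeledPoint])] = {
    // tag every point with a fold id 0..k-1, assigned round-robin within its label group
    val withFold: RDD[(Int, LabeledPoint)] = data
        .groupBy(_.label)
        .flatMap { case (_, points) =>
            Random.shuffle(points.toList).zipWithIndex.map { case (p, i) => (i % k, p) }
        }
    withFold.cache() // keep the random fold assignment consistent across the k splits

    Array.tabulate(k) { f =>
        val training = withFold.filter(_._1 != f).values
        val testing  = withFold.filter(_._1 == f).values
        (training, testing)
    }
}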