从 Python SKlearn 中的 KFold 拆分中排除某些索引

Exclude certain indices from a KFold split in Python SKlearn

我正在使用 SKlearn KFold 如下:

        kf = KFold(10000, n_folds=5, shuffle=True, random_state=88)

但是,我想(仅)从训练折叠中排除某些索引。如何实现?谢谢。

不知是否可以通过sklearn.cross_validation.PredefinedSplit实现?


更新:KFold 实例将与 XGBoost 一起用于 xgb.cv 的 folds 参数。 Python API here 指出折叠应该是 "a KFold or StratifiedKFold instance"。

但是,我会尝试像上面那样生成 KFolds,迭代训练折叠索引,修改它们,然后像​​这样手动定义 custom_cv:

custom_cv = zip(train_indices, test_indices)

如果您想从训练集中删除索引,但如果它们在测试集中就可以,那么这种方法可行:

kf_list = list(kf)

这将 return 一个元组列表,可以用与 KFold 实例相同的方式对其进行迭代。然后,您可以根据需要简单地修改索引,您的 KFold 实例将保持不变。您可以将 KFold 对象视为一个整数数组,代表索引和让您动态生成折叠的方法。

这是迭代器协议实现方式的重要部分的源代码,非常简单明了:

https://github.com/scikit-learn/scikit-learn/blob/51a765a/sklearn/cross_validation.py#L254

def _iter_test_indices(self):
    n = self.n
    n_folds = self.n_folds
    fold_sizes = (n // n_folds) * np.ones(n_folds, dtype=np.int)
    fold_sizes[:n % n_folds] += 1
    current = 0
    for fold_size in fold_sizes:
        start, stop = current, current + fold_size
        yield self.idxs[start:stop]
        current = stop