StratifiedKFold vs KFold in scikit-learn

I use this code to test KFold and StratifiedKFold:

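import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Placeholder features: the original X values are cut off in the post.
# Only the number of rows (8) affects the split indices printed below.
X = np.arange(16).reshape(8, 2)
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])

# random_state is omitted: it has no effect when shuffle=False, and recent
# scikit-learn releases raise a ValueError for that combination.
sfolder = StratifiedKFold(n_splits=4, shuffle=False)
folder = KFold(n_splits=4, shuffle=False)

# StratifiedKFold uses y to keep the class proportions equal in every fold.
for train, test in sfolder.split(X, y):
    print('Train: %s | test: %s' % (train, test))
print("StratifiedKFold done")

# KFold ignores y and cuts the rows into consecutive blocks.
for train, test in folder.split(X, y):
    print('Train: %s | test: %s' % (train, test))
print("KFold done")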


Train: [1 2 3 5 6 7] | test: [0 4]
Train: [0 2 3 4 6 7] | test: [1 5]
Train: [0 1 3 4 5 7] | test: [2 6]
Train: [0 1 2 4 5 6] | test: [3 7]
StratifiedKFold done
Train: [2 3 4 5 6 7] | test: [0 1]
Train: [0 1 4 5 6 7] | test: [2 3]
Train: [0 1 2 3 6 7] | test: [4 5]
Train: [0 1 2 3 4 5] | test: [6 7]
KFold done


When should I use KFold instead of StratifiedKFold?

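StratifiedKFold: This cross-validation object is a variation of KFold that returns stratified folds. The folds are made by preserving the percentage of samples for each class.

KFold: Split dataset into k consecutive folds.

StratifiedKFold is used when you need to keep the percentage of each class balanced between the training and test sets. If you don't need that, use KFold.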

Assume a classification problem with three classes (A, B, C) to predict.

 Class   Number of instances
 A       50
 B       50
 C       50


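**StratifiedKFold:** If the dataset is divided into 5 folds, each fold will contain 10 instances from each class, i.e. the number of instances per class is equal and the class distribution is uniform in every fold.

**KFold:** Each fold takes 30 instances without looking at their labels (consecutive blocks by default, random ones with shuffle=True), so the number of instances per class in a fold may or may not be equal.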
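
The 150-instance example above can be checked with a short sketch. It assumes the labels are stored sorted by class, which is the case where plain KFold goes most visibly wrong. The helper `class_counts_per_fold` and the all-zero `X` are illustrative only and do not come from the original answer.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# 150 labels: 50 each of classes A, B, C, stored in sorted order (an assumption).
y = np.repeat(["A", "B", "C"], 50)
# Placeholder features: the splitters only need the row count (and y for stratifying).
X = np.zeros((len(y), 1))


def class_counts_per_fold(splitter):
    """Return one dict per test fold, mapping class label to instance count."""
    counts = []
    for _, test_idx in splitter.split(X, y):
        labels, n = np.unique(y[test_idx], return_counts=True)
        counts.append({str(k): int(v) for k, v in zip(labels, n)})
    return counts


stratified_counts = class_counts_per_fold(StratifiedKFold(n_splits=5))
kfold_counts = class_counts_per_fold(KFold(n_splits=5))

print("StratifiedKFold:", stratified_counts)
print("KFold:          ", kfold_counts)
```

With sorted labels, every StratifiedKFold test fold holds 10 instances of each class, while the first unshuffled KFold test fold holds 30 instances of class A and nothing else.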

**When to use**

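For classification tasks use StratifiedKFold, and for regression tasks use KFold (StratifiedKFold needs discrete class labels to stratify on).
But if the dataset contains a large number of instances, both StratifiedKFold and KFold can be used for classification, because with enough shuffled samples the class proportions in each fold end up close to the overall ones.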
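
A rough way to see the large-dataset point is to compare the class proportions per fold. The sketch below uses a made-up dataset of 10,000 binary labels with roughly 30% positives. The sizes, the seed, and the helper `positive_rate_per_fold` are assumptions for illustration, not part of the original answer.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

rng = np.random.default_rng(0)
# Hypothetical large dataset: 10,000 binary labels, roughly 30% in class 1.
y = (rng.random(10_000) < 0.3).astype(int)
X = np.zeros((len(y), 1))  # placeholder features, unused by the splitters


def positive_rate_per_fold(splitter):
    """Fraction of class-1 labels in each test fold."""
    return [float(y[test_idx].mean()) for _, test_idx in splitter.split(X, y)]


kfold_rates = positive_rate_per_fold(
    KFold(n_splits=5, shuffle=True, random_state=0)
)
stratified_rates = positive_rate_per_fold(
    StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
)

print("overall rate:   ", float(y.mean()))
print("KFold:          ", kfold_rates)
print("StratifiedKFold:", stratified_rates)
```

On data like this, shuffled KFold should land within a few percentage points of the overall rate in every fold, while StratifiedKFold matches it almost exactly. With small or strongly imbalanced datasets the KFold deviations grow, which is where stratification matters.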

I think you should ask "When should I use StratifiedKFold instead of KFold?".


KFold is a cross-validator that divides the dataset into k folds.

Stratification ensures that each fold of the dataset has the same proportion of observations with a given label.





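Suppose there is a dataset with 16 data points and an imbalanced class distribution. In the dataset, 12 data points belong to class A and the rest (i.e. 4) belong to class B. The ratio of class B to class A is 1/3. If we use StratifiedKFold and set k = 4, then in each iteration the training set will include 9 data points from class A and 3 data points from class B, while the test set will include 3 data points from class A and 1 data point from class B.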

As we can see, the splits produced by StratifiedKFold preserve the class distribution of the dataset, while KFold does not take it into account.
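
The 16-point example can be reproduced with a minimal sketch. It assumes the 12 class-A points are stored before the 4 class-B points. The helper `fold_composition` and the placeholder `X` are made up for illustration.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# 12 points of class A followed by 4 points of class B (sorted order is an assumption).
y = np.array(["A"] * 12 + ["B"] * 4)
X = np.zeros((len(y), 1))  # placeholder features


def fold_composition(splitter):
    """Return (train_A, train_B, test_A, test_B) counts for every split."""
    rows = []
    for train_idx, test_idx in splitter.split(X, y):
        rows.append((
            int((y[train_idx] == "A").sum()),
            int((y[train_idx] == "B").sum()),
            int((y[test_idx] == "A").sum()),
            int((y[test_idx] == "B").sum()),
        ))
    return rows


stratified_rows = fold_composition(StratifiedKFold(n_splits=4))
kfold_rows = fold_composition(KFold(n_splits=4))

print("StratifiedKFold (train_A, train_B, test_A, test_B):", stratified_rows)
print("KFold           (train_A, train_B, test_A, test_B):", kfold_rows)
```

Every StratifiedKFold split comes out as 9 A and 3 B for training and 3 A and 1 B for testing. With this ordering, unshuffled KFold produces three test folds containing only class A and one containing only class B, so the last split trains on no class-B points at all.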