Python ValueError: n_splits=3 cannot be greater than the number of members in each class

Python ValueError: n_splits=3 cannot be greater than the number of members in each class

我正在做人脸识别项目,我有两个人,每个人有 2 张脸

1. personA
    image1.jpg
    image2.jpg


2. personB
    image1.jpg
    image2.jpg

我正尝试在上述数据集的面部嵌入上训练模型,如下所示:

params = {"C": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0], "gamma": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]}
model = GridSearchCV(SVC(kernel="rbf", gamma="auto", probability=True), params, cv=3, n_jobs=-1)
model.fit(data["embeddings"], labels)

其中 data["embeddings"]labels 的长度是 4data["embeddings']包含personA,personB的face embedding的ndarray

data['embeddings'] = [
                         [0.02331057, -0.01995077, ..], 
                         [-0.00034041,  0.02753334, ..], 
                         [0.02454563, -0.03797123, ...], 
                         [0.10561685, -0.08444008, ...]
                     ]

labels = [0 0 1 1]

但我在 model.fit(data["embeddings"], labels):

遇到以下错误
ValueError: n_splits=3 cannot be greater than the number of members in each class.

我无法理解这个错误。谁能解释一下这个问题,我该如何解决?

仔细阅读,错误信息清晰且不言自明;它只是告诉您,由于每个 类 总共只有两 (2) 个样本,因此您无法进行 3 折交叉验证。这将需要 至少 每个 类.

3 个样本

我想它应该可以与 cv=2 一起使用而不会引发任何错误,但是您的整个方法(即只有 4 个样本的数据集)似乎非常值得怀疑。