交叉验证:查找不属于训练集的测试集的行索引

Cross-validation: finding row indices for a test set that aren't part of a training set

我需要做的是从 numpy 矩阵中随机挑选(替换)50 行,用于训练线性分隔符。

然后,我需要使用我没有选择的行来测试线性分隔符。

对于第一部分,A 是我的完整数据矩阵,我这样做:

A_train = A[np.random.randint(A.shape[0],size=50),:]

但是我目前没有有效的方法找到:

A_test = ...

其中 A_test 不包含与 A_train 相同的行。我该怎么做?

这个问题的关键是 A 是一个 n x m 矩阵,而不是一维矩阵...

您可以使用 np.setdiff1d 查找未包含在您的训练集中的行索引:

import numpy as np

gen = np.random.RandomState(0)

n_total = 1000
n_train = 800

train_idx = gen.choice(n_total, size=n_train)
test_idx = np.setdiff1d(np.arange(n_total), train_idx)

有放回抽样的一个结果是,有资格包含在测试集中的示例数量将根据训练集中重复示例的数量而变化:

print(test_idx.size)
# 439

如果您想确保测试集的大小一致,您可以从不在训练集中的索引集中进行替换重采样:

n_test = 200
test_idx2 = gen.choice(test_idx, size=n_test)

如果您实际上并不关心有放回抽样,那么一个更简单的选择是生成所有索引的随机排列,然后将前 N 个作为训练示例,其余作为测试示例:

idx = gen.permutation(n_total)
train_idx, test_idx = idx[:n_train], idx[n_train:]

或者您可以使用 np.random.shuffle.

就地随机排列数组的行

我还应该指出,scikit-learn 具有 various convenience methods 用于将数据划分为训练集和测试集以进行交叉验证。