使二进制堆叠示例适应多类
Adapting binary stacking example to multiclass
我一直在学习this example of stacking。在这种情况下,每组 K-folds 产生一列数据,并且对每个分类器重复这一过程。即:混合矩阵是:
dataset_blend_train = np.zeros((X.shape[0], len(clfs)))
dataset_blend_test = np.zeros((X_submission.shape[0], len(clfs)))
我需要堆叠来自多类问题的预测(每个样本可能有 15 个不同 类)。这将为每个 clf 生成一个 n*15 矩阵。
这些矩阵是否应该水平连接?还是应该在应用逻辑回归之前以其他方式组合它们?谢谢。
您可以通过两种方式使代码适应多 class 问题:
- 横向连接概率,即您需要创建:
dataset_blend_train = np.zeros((X.shape[0], len(clfs)*numOfClasses))
dataset_blend_test = np.zeros((X_submission.shape[0], len(clfs)*numOfClasses))
- 不使用概率,而是对基础模型使用 class 预测。这样你可以保持数组的大小相同,但你只使用
predict
而不是 predict_proba
。
我都成功使用过,但哪个效果更好可能取决于数据集。
当您循环遍历每个分类器时,还存在扩展功能的问题。我使用以下内容:
db_train = np.zeros((X_train.shape[0], np.unique(y).shape[0]))
db_test = clf.predict_proba(X_test)
...
try:
dataset_blend_train
except NameError:
dataset_blend_train = db_train
else:
dataset_blend_train = np.hstack((dataset_blend_train, db_train))
try:
dataset_blend_test
except NameError:
dataset_blend_test = db_test
else:
dataset_blend_test = np.hstack((dataset_blend_test, db_test))
我一直在学习this example of stacking。在这种情况下,每组 K-folds 产生一列数据,并且对每个分类器重复这一过程。即:混合矩阵是:
dataset_blend_train = np.zeros((X.shape[0], len(clfs)))
dataset_blend_test = np.zeros((X_submission.shape[0], len(clfs)))
我需要堆叠来自多类问题的预测(每个样本可能有 15 个不同 类)。这将为每个 clf 生成一个 n*15 矩阵。
这些矩阵是否应该水平连接?还是应该在应用逻辑回归之前以其他方式组合它们?谢谢。
您可以通过两种方式使代码适应多 class 问题:
- 横向连接概率,即您需要创建:
dataset_blend_train = np.zeros((X.shape[0], len(clfs)*numOfClasses))
dataset_blend_test = np.zeros((X_submission.shape[0], len(clfs)*numOfClasses))
- 不使用概率,而是对基础模型使用 class 预测。这样你可以保持数组的大小相同,但你只使用
predict
而不是predict_proba
。
我都成功使用过,但哪个效果更好可能取决于数据集。
当您循环遍历每个分类器时,还存在扩展功能的问题。我使用以下内容:
db_train = np.zeros((X_train.shape[0], np.unique(y).shape[0]))
db_test = clf.predict_proba(X_test)
...
try:
dataset_blend_train
except NameError:
dataset_blend_train = db_train
else:
dataset_blend_train = np.hstack((dataset_blend_train, db_train))
try:
dataset_blend_test
except NameError:
dataset_blend_test = db_test
else:
dataset_blend_test = np.hstack((dataset_blend_test, db_test))