LightGBM 使用 pred_contrib=True 预测多类:返回数组中 SHAP 值的顺序

LightGBM predict with pred_contrib=True for multiclass: order of SHAP values in the returned array

LightGBM predict 方法 pred_contrib=True returns shape =(n_samples, (n_features + 1) * n_classes).

数组

这个数组第二维的数据顺序是什么?

也就是说有两个问题:

  1. 重塑此数组以使用它的正确方法是什么:shape = (n_samples, n_features + 1, n_classes)shape = (n_samples, n_classes, n_features + 1)
  2. 在特征维度中,有 n_features 个条目,每个特征一个,一个(无用的)条目表示与任何特征无关的贡献。这些条目的顺序是什么:条目 1、...、n_features 中的特征贡献与它们在数据集中出现的顺序相同,其余(无用的)条目位于索引 0 或其他方式?

答案如下:

  1. 正确的形状是(n_samples, n_classes, n_features + 1)
  2. 特征贡献在条目 1,...,n_features 中的顺序与它们在数据集中出现的顺序相同,其余(无用)条目位于索引 0。

下面的代码令人信服地展示了它:

import lightgbm, pandas, numpy
params = {'objective': 'multiclass', 'num_classes': 4, 'num_iterations': 10000,
          'metric': 'multiclass', 'early_stopping_rounds': 10}
train_df = pandas.DataFrame({'f0': [0, 1, 2, 3] * 50, 'f1': [0, 0, 1] * 66 + [1, 2]}, dtype=float)
val_df = train_df.copy()
train_target = pandas.Series([0, 1, 2, 3] * 50)
val_target = pandas.Series([0, 1, 2, 3] * 50)
train_set = lightgbm.Dataset(train_df, train_target)
val_set = lightgbm.Dataset(val_df, val_target)
model = lightgbm.train(params=params, train_set=train_set, valid_sets=[val_set, train_set])
feature_contribs = model.predict(val_df, pred_contrib=True)
print('Shape of SHAP:', feature_contribs.shape)
# Shape of SHAP: (200, 12)
print('Averages over samples:', numpy.mean(feature_contribs, axis=0))
# Averages over samples: [ 3.99942301e-13 -4.02281771e-13 -4.30029167e+00 -1.90606677e-05
#  1.90606677e-05 -4.04157656e+00  2.24205077e-05 -2.24205077e-05
#  -4.04265615e+00 -3.70370401e-15  5.20335728e-18 -4.30029167e+00]
feature_contribs.shape = (200, 4, 3)
print('Mean feature contribs:', numpy.mean(feature_contribs, axis=(0, 1)))
# Mean feature contribs: [ 8.39960111e-07 -8.39960113e-07 -4.17120401e+00]

(每个输出在下一行中显示为注释。)

解释如下

我创建了一个包含两个特征的数据集,标签与第二个特征相同。

我希望只有第二个功能能做出重大贡献。

对样本的 SHAP 输出进行平均后,我们得到一个形状为 (12,) 的数组,在位置 2、5、8、11(从零开始)有非零值。

这表明这个数组的正确形状是 (4, 3)。

以这种方式重塑并对样本和 类 进行平均后,我们得到一个形状为 (3,) 的数组,末尾有非零条目。

这表明该数组的最后一个条目对应于最后一个特征。也就是说0位置的条目不对应任何特征,后面的条目对应特征。