xgboost 模型的内部节点预测

Internal node predictions of xgboost model

是否可以计算 xgboost 模型的内部节点预测? R 包 gbm 提供了对每棵树的内部节点的预测。

然而,xgboost 输出仅显示模型最后一片叶子的预测。

xgboost 输出:

请注意,质量列在第 6 行中有叶节点的最终预测。我也希望每个内部节点都有该值。

   Tree Node  ID    Feature    Split  Yes   No Missing     Quality  Cover
1:    0    0 0-0 Sex=female  0.50000  0-1  0-2     0-1 246.6042790 222.75
2:    0    1 0-1        Age 13.00000  0-3  0-4     0-4  22.3424225 144.25
3:    0    2 0-2   Pclass=3  0.50000  0-5  0-6     0-5  60.1275253  78.50
4:    0    3 0-3      SibSp  2.50000  0-7  0-8     0-7  23.6302433   9.25
5:    0    4 0-4       Fare 26.26875  0-9 0-10     0-9  21.4425507 135.00
6:    0    5 0-5       Leaf       NA <NA> <NA>    <NA>   0.1747126  42.50

R gbm 输出:

在 R gbm 包输出中,预测列包含叶节点 (SplitVar == -1) 和内部节点的值。我想从 xgboost 模型中访问这些值

   SplitVar SplitCodePred LeftNode RightNode MissingNode ErrorReduction Weight   Prediction
0         1   0.000000000        1         8          15      32.564591    445  0.001132514
1         2   9.500000000        2         3           7       3.844470    282 -0.085827382
2        -1   0.119585850       -1        -1          -1       0.000000     15  0.119585850
3         0   1.000000000        4         5           6       3.047926    207 -0.092846157
4        -1  -0.118731665       -1        -1          -1       0.000000    165 -0.118731665
5        -1   0.008846912       -1        -1          -1       0.000000     42  0.008846912
6        -1  -0.092846157       -1        -1          -1       0.000000    207 -0.092846157

问题:

如何访问或计算 xgboost 模型内部节点的预测?我想将它们用于贪婪的穷人版本的 SHAP 分数。

这个问题的解决方案是用 all_stats=True 转储 xgboost json 对象。这将 cover 统计信息添加到输出中,可用于通过内部节点分布叶点:

def _calculate_contribution(node: AnyNode) -> float32:
        if isinstance(node, Leaf):
            return node.contrib
        else:
            return (
                node.left.cover * Node._calculate_contribution(node.left)
                + node.right.cover * Node._calculate_contribution(node.right)
            ) / node.cover

内部贡献是子贡献的加权平均值。使用此方法,生成的结果与使用 pred_contribs=Trueapprox_contribs=True.

调用预测方法时返回的结果完全匹配