从 xgboost.dump 中找到二叉树的所有路径

find all path for binary tree from xgboost.dump

我有一个 xgboost.dump 许多树的文本文件。 我想找到所有路径以获得每条路径的价值。 这是其中一棵树。

tree[0]:
0:[a<0.966398] yes=1,no=2,missing=1
    1:[b<0.323071] yes=3,no=4,missing=3
        3:[c<0.461248] yes=7,no=8,missing=7
            7:leaf=0.00972768
            8:leaf=-0.0179376
        4:[a<0.379082] yes=9,no=10,missing=9
            9:leaf=0.0146003
            10:leaf=0.0454369
    2:[b<0.322352] yes=5,no=6,missing=5
        5:[c<0.674868] yes=11,no=12,missing=11
            11:leaf=0.0497964
            12:leaf=0.00953781
        6:[f<0.598267] yes=13,no=14,missing=13
            13:leaf=0.0504545
            14:leaf=0.0867654

我想把所有路径都改成

path1, a<0.966398, b<0.323071, c<0.461248, leaf = 0.00097268
path2, a<0.966398, b<0.323071, c>0.461248, leaf = -0.0179376
path3, a<0.966398, b>0.323071, a<0.379082, leaf = 0.0146003
path4, a<0.966398, b>0.323071, a>0.379082, leaf = 0.0454369
path5, a>0.966398, b<0.322352, c<0.674868, leaf = 0.0497964
path6, a>0.966398, b<0.322352, c>0.674868, leaf = 0.00953781
path7, a>0.966398, b>0.322352, f<0.598267, leaf = 0.0504545
path8, a>0.966398, b>0.322352, f>0.598267, leaf = 0.0864654

我已经尝试列出所有可能的路径,例如

array([[ 0,  1,  3,  7],
       [ 0,  1,  3,  8],
       [ 0,  1,  4,  9],
       [ 0,  1,  4, 10],
       [ 0,  2,  5, 11],
       [ 0,  2,  5, 12],
       [ 0,  2,  6, 13],
       [ 0,  2,  6, 14]])

但是这种方式一旦max_depth更高就会出错,一些分支会停止增长,路径会出错。 所以我需要解析文本文件中的 yes, no 来生成真实的、正确的路径。 有什么建议么? 谢谢!

这是我使用 R 实现解决此问题的方法。其他语言的用户可以按照逻辑进行实物复制。

首先,我从 xgb.model.dt.tree() 生成的模型转储文件开始。

然后,我编写了一个函数来解析从任意节点到转储模型的单个树中的最终父节点的有效路径。

稍后,我使用 purrr::by_row() 将此函数应用于模型转储中的所有终端节点 "Leaf" 记录,并根据需要转换结果。

这个函数有两个参数,一个是它正在测试的树,另一个是终端节点的标识。它遵循以下一般步骤:

  1. 从每棵树的目标(终端)节点开始,在 c("Yes" "No", "Missing") 决策分裂。
  2. 将此有效的父节点 ID 连接到一个向量中,该向量将用于跟踪从目标节点到最终父节点的路径的每个步骤。此向量在函数完成时 returned。
  3. 接下来,对链上的每个节点重复 "who is my parent" 步骤,直到路径到达最终父节点(此节点 ID 始终以“-0”结尾),同时为每个新步骤更新路径向量在链中。
  4. 一旦函数到达终端节点,return() 路径。

在我的例子中,我使用 purrr::by_row() 将此函数应用于模型转储中的所有 "Leaf" 节点,同时 .collat​​ing = "rows" 将路径表示为附加行输出。

这也很可能不是最快的方法。

nrounds 的增加或 xgb.booster 模型中的 max_depth 将导致此过程的运行时间增加。您可以使用树的一个子集(xgb.model.dt.tree() 的参数 n_first_tree = N)来开发您的方法,以估计在最终模型中解析出整个终端节点路径所需的时间。在我的例子中,在 max_depth = 5 处拥有约 500 棵树的模型可能需要 30 分钟以上的时间。