Getting a RegressionTree from an ML.NET FastTree model

Good day.

I'm having trouble inspecting the internal structure of a decision tree model produced by ML.NET FastTree.

I built my model by following Microsoft's tutorial:

https://docs.microsoft.com/en-us/dotnet/machine-learning/tutorials/predict-prices

MLContext mlContext = new MLContext(seed: 0);
var model = Train(mlContext, _trainDataPath);

All I get back is a regression model, not a decision tree structure.

I'd like to export a proper "Tree" structure from it. Could you help me find a way to do that? Thanks in advance.

I think the easiest way to get at the tree structure is to set a breakpoint right after 'model' is trained and then inspect the model in Visual Studio's Watch/Autos window.

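The model is probably a chain of transformers whose last element is the decision tree prediction transformer. You can drill into that transformer to get the tree structure: expand its 'model parameters' and you will eventually find the TreeEnsemble.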

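That will be a list of decision trees rather than a single tree, because gradient boosting builds an ensemble in which each new tree is fitted to correct the errors of the ones before it.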
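If you would rather reach the ensemble from code than from the Watch window, a sketch along these lines should work for a regression pipeline like the tutorial's, assuming it ends in mlContext.Regression.Trainers.FastTree() and that the fitted model is held as an ITransformer. It takes the last transformer of the fitted chain and casts it to RegressionPredictionTransformer<FastTreeRegressionModelParameters>. The SampleRow class, the synthetic data and the TreeExtraction helper are hypothetical stand-ins, not part of the tutorial. It needs the Microsoft.ML and Microsoft.ML.FastTree NuGet packages.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;

// Hypothetical input row standing in for the tutorial's data class.
public class SampleRow
{
    public float X1 { get; set; }
    public float X2 { get; set; }
    public float Label { get; set; }
}

public static class TreeExtraction
{
    // Trains a small FastTree regression model on synthetic data and returns it
    // as a plain ITransformer, the same static type the tutorial's Train() returns.
    public static ITransformer TrainSample(MLContext mlContext)
    {
        var rows = Enumerable.Range(0, 200).Select(i => new SampleRow
        {
            X1 = i % 20,
            X2 = i / 20,
            Label = 3f * (i % 20) + (i / 20 > 5 ? 10f : 0f)
        }).ToList();
        IDataView data = mlContext.Data.LoadFromEnumerable(rows);

        var pipeline = mlContext.Transforms.Concatenate("Features", nameof(SampleRow.X1), nameof(SampleRow.X2))
            .Append(mlContext.Regression.Trainers.FastTree(
                numberOfLeaves: 4, numberOfTrees: 3, minimumExampleCountPerLeaf: 5));
        return pipeline.Fit(data);
    }

    // Returns the tree ensemble of a fitted model whose last step is FastTree regression.
    public static RegressionTreeEnsemble GetEnsemble(ITransformer model)
    {
        // A fitted pipeline is a TransformerChain, which enumerates its transformers.
        ITransformer last = model is IEnumerable<ITransformer> chain ? chain.Last() : model;

        if (last is RegressionPredictionTransformer<FastTreeRegressionModelParameters> predictor)
            return predictor.Model.TrainedTreeEnsemble;

        throw new InvalidOperationException(
            $"Last transformer is {last.GetType().Name}, not a FastTree regression predictor.");
    }

    // Prints each tree's numerical split arrays and leaf values (categorical splits are not handled).
    public static void Dump(RegressionTreeEnsemble ensemble)
    {
        Console.WriteLine($"Bias: {ensemble.Bias}, trees: {ensemble.Trees.Count}");
        for (int t = 0; t < ensemble.Trees.Count; t++)
        {
            RegressionTree tree = ensemble.Trees[t];
            Console.WriteLine($"Tree {t} (weight {ensemble.TreeWeights[t]}):");
            for (int n = 0; n < tree.NumberOfNodes; n++)
            {
                Console.WriteLine(
                    $"  node {n}: feature {tree.NumericalSplitFeatureIndexes[n]} <= {tree.NumericalSplitThresholds[n]} " +
                    $"? left {tree.LeftChild[n]} : right {tree.RightChild[n]}");
            }
            Console.WriteLine($"  leaf values: {string.Join(", ", tree.LeafValues)}");
        }
    }
}
```

With the tutorial's code, calling TreeExtraction.Dump(TreeExtraction.GetEnsemble(model)) right after Train should print the split structure. Going through IEnumerable<ITransformer> avoids spelling out the chain's exact generic type. For a FastTree binary classifier the last transformer's Model is a calibrated wrapper, so you would go through Model.SubModel instead, as in the test below.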

We now have an API that lets you retrieve the gradient-boosted decision trees; see the example below:

// Requires: using Microsoft.ML; using Microsoft.ML.Trainers.FastTree; using Xunit;
// TweetSentiment, TestCommon, DataDir and TestDatasets come from the surrounding test code (not shown).
public void InspectFastTreeModelParameters()
{
    var mlContext = new MLContext(seed: 1);

    var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(TestCommon.GetDataPath(DataDir, TestDatasets.Sentiment.trainFilename),
        hasHeader: TestDatasets.Sentiment.fileHasHeader,
        separatorChar: TestDatasets.Sentiment.fileSeparator,
        allowQuoting: TestDatasets.Sentiment.allowQuoting);

    // Create a training pipeline.
    var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
        .AppendCacheCheckpoint(mlContext)
        .Append(mlContext.BinaryClassification.Trainers.FastTree(
            new FastTreeBinaryTrainer.Options { NumberOfLeaves = 5, NumberOfTrees = 3, NumberOfThreads = 1 }));

    // Fit the pipeline.
    var model = pipeline.Fit(data);

    // Extract the boosted tree model. The binary FastTree model is calibrated,
    // so SubModel returns the underlying FastTreeBinaryModelParameters.
    var fastTreeModel = model.LastTransformer.Model.SubModel;

    // Extract the learned GBDT model.
    var treeCollection = fastTreeModel.TrainedTreeEnsemble;

    // Make sure the tree models were formed as expected.
    Assert.Equal(3, treeCollection.Trees.Count);
    Assert.Equal(3, treeCollection.TreeWeights.Count);
    Assert.All(treeCollection.TreeWeights, weight => Assert.Equal(1.0, weight));
    Assert.All(treeCollection.Trees, tree =>
    {
        Assert.Equal(5, tree.NumberOfLeaves);
        // A binary tree with 5 leaves has 4 internal (split) nodes.
        Assert.Equal(4, tree.NumberOfNodes);
        Assert.Equal(tree.NumberOfNodes, tree.SplitGains.Count);
        Assert.Equal(tree.NumberOfNodes, tree.NumericalSplitThresholds.Count);
        Assert.All(tree.CategoricalSplitFlags, flag => Assert.False(flag));
        Assert.Equal(0, tree.GetCategoricalSplitFeaturesAt(0).Count);
        Assert.Equal(0, tree.GetCategoricalCategoricalSplitFeatureRangeAt(0).Count);
    });

    // Add baselines for the model.
    // Verify that there is no bias.
    Assert.Equal(0, treeCollection.Bias);
    // Check the parameters of the final tree.
    var finalTree = treeCollection.Trees[2];
    // Negative child indices refer to leaves rather than internal nodes.
    Assert.Equal(new int[] { 2, -2, -1, -3 }, finalTree.LeftChild);
    Assert.Equal(new int[] { 1, 3, -4, -5 }, finalTree.RightChild);
    Assert.Equal(new int[] { 14, 294, 633, 266 }, finalTree.NumericalSplitFeatureIndexes);
    var expectedSplitGains = new double[] { 0.52634223978445616, 0.45899249367725858, 0.44142707650267105, 0.38348634823264854 };
    var expectedThresholds = new float[] { 0.0911167f, 0.06509889f, 0.019873254f, 0.0361835f };
    for (int i = 0; i < finalTree.NumberOfNodes; ++i)
    {
        Assert.Equal(expectedSplitGains[i], finalTree.SplitGains[i], 6);
        Assert.Equal(expectedThresholds[i], finalTree.NumericalSplitThresholds[i], 6);
    }
}
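To turn these parallel arrays into an actual tree, you can follow the child indices from the root. The sketch below does that in plain C# with no ML.NET dependency, using the final tree's LeftChild, RightChild, NumericalSplitFeatureIndexes and NumericalSplitThresholds from the test above. Two conventions are assumptions here rather than quotes from the docs: that a negative child index c denotes leaf ~c (bitwise complement, which fits the leaf indices 0 to 4 implied by { 2, -2, -1, -3 } and { 1, 3, -4, -5 }), and that a feature value less than or equal to the threshold goes left. SimpleTree, ExampleTree and Ensemble are hypothetical helpers, the leaf values are made up for illustration, and categorical splits are ignored (the test asserts there are none).

```csharp
using System;
using System.Collections.Generic;
using System.Text;

// Plain-array copy of the fields read from an ML.NET RegressionTree.
public sealed class SimpleTree
{
    public int[] LeftChild = Array.Empty<int>();
    public int[] RightChild = Array.Empty<int>();
    public int[] SplitFeature = Array.Empty<int>();
    public float[] Threshold = Array.Empty<float>();
    public double[] LeafValues = Array.Empty<double>();

    // Walks from the root (node 0) to a leaf and returns the leaf index.
    // Assumed conventions: child c < 0 means leaf ~c; value <= threshold goes left.
    public int GetLeaf(IReadOnlyList<float> features)
    {
        if (SplitFeature.Length == 0) return 0; // single-leaf tree
        int node = 0;
        while (node >= 0)
        {
            node = features[SplitFeature[node]] <= Threshold[node]
                ? LeftChild[node]
                : RightChild[node];
        }
        return ~node;
    }

    public double GetOutput(IReadOnlyList<float> features) => LeafValues[GetLeaf(features)];

    // Renders the tree as indented if/else text.
    public string ToText()
    {
        var sb = new StringBuilder();
        Append(sb, SplitFeature.Length == 0 ? ~0 : 0, 0);
        return sb.ToString();
    }

    private void Append(StringBuilder sb, int child, int depth)
    {
        string indent = new string(' ', 2 * depth);
        if (child < 0)
        {
            int leaf = ~child;
            sb.AppendLine($"{indent}leaf {leaf}: {LeafValues[leaf]}");
            return;
        }
        sb.AppendLine($"{indent}if f[{SplitFeature[child]}] <= {Threshold[child]}:");
        Append(sb, LeftChild[child], depth + 1);
        sb.AppendLine($"{indent}else:");
        Append(sb, RightChild[child], depth + 1);
    }
}

public static class ExampleTree
{
    // Structure of treeCollection.Trees[2] from the test; leaf values are hypothetical.
    public static SimpleTree Create() => new SimpleTree
    {
        LeftChild = new[] { 2, -2, -1, -3 },
        RightChild = new[] { 1, 3, -4, -5 },
        SplitFeature = new[] { 14, 294, 633, 266 },
        Threshold = new[] { 0.0911167f, 0.06509889f, 0.019873254f, 0.0361835f },
        LeafValues = new[] { -0.4, -0.1, 0.1, 0.2, 0.5 }
    };
}

public static class Ensemble
{
    // Assumed additive form of the boosted ensemble: Bias + sum_i TreeWeights[i] * Trees[i](x).
    public static double Score(IReadOnlyList<SimpleTree> trees, IReadOnlyList<double> weights,
                               double bias, IReadOnlyList<float> features)
    {
        double sum = bias;
        for (int i = 0; i < trees.Count; i++)
            sum += weights[i] * trees[i].GetOutput(features);
        return sum;
    }
}
```

With ML.NET in scope, the arrays can be copied straight from a RegressionTree, e.g. LeftChild = tree.LeftChild.ToArray() (with System.Linq), and the real leaf values from tree.LeafValues. For the binary classifier in the test, the ensemble score is the raw, uncalibrated output; the calibrator is applied on top of it.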