ML.NET MulticlassEvaluationMetrics 在 testSet 上的评估结果始终为 0

ML .NET MulticlassEvaluationMetrics always evaluates 0 on testSet

我一直在参照 ML.NET 教程中的这个示例:https://github.com/dotnet/samples/tree/master/machine-learning/tutorials/GitHubIssueClassification

我基于该示例构建了自己的版本:它从 .xlsx 文件(另一个数据集)读取数据,并将其拆分为训练集和测试集。它运行正常,也能做出正确的预测,但我怎么也弄不明白,为什么把 _testSet 传入评估时,所有评估指标总是显示 0;而传入 _trainSet 时结果为 1,这符合预期。

即使我将 TestFraction 设置为 0.5,评估结果仍然为 0。

using System;
using System.Data;
using System.Data.OleDb;
using System.Collections.Generic;
using System.Linq;
using System.IO;
using Microsoft.ML;

namespace Test.Repository
{
    /// <summary>
    /// Input row type for the classifier. In <c>Googler.ProcessData</c>, <c>Topic</c> is mapped
    /// to the "Label" key column and <c>Subject</c> is the text that gets featurized.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <c>LoadColumn</c> is declared in the <c>Microsoft.ML.Data</c> namespace, and no
    /// using directive for it is visible in this file — confirm it is in scope.
    /// In the visible code, rows are created from a DataTable and passed to LoadFromEnumerable,
    /// so the column indices below are presumably not consulted — verify.
    /// </remarks>
    public class SearchEntry
    {
        // Column index 0: the class label (read from the "Topic" column of the Excel sheet).
        [LoadColumn(0)]
        public string Topic { get; set; }
        // Column index 1: the free text used as the model's only feature source.
        [LoadColumn(1)]
        public string Subject { get; set; }
    }

    /// <summary>
    /// Output type of the prediction engine declared in <c>Googler</c>
    /// (<c>PredictionEngine&lt;SearchEntry, SearchPrediction&gt;</c>).
    /// </summary>
    public class SearchPrediction
    {
        // Bound to the "PredictedLabel" column, which the training pipeline converts back
        // from a key to its original value via MapKeyToValue("PredictedLabel").
        [ColumnName("PredictedLabel")]
        public string Topic;
    }

    /// <summary>
    /// Loads labelled (Topic, Subject) rows from an Excel sheet, splits them into train and test
    /// sets, trains a multiclass text classifier with ML.NET and prints evaluation metrics.
    /// Intended call order, judging by the state each method reads:
    /// LoadModelData, ProcessData, BuildAndTrainModel, Evaluate.
    /// </summary>
    /// <remarks>
    /// All working state is held in static fields, so every instance of this class shares it.
    /// </remarks>
    public class Googler
    {
        // Directory of the first command-line argument; not referenced by the visible code.
        private static string _appPath => Path.GetDirectoryName(Environment.GetCommandLineArgs()[0]);
        // Relative path of the workbook queried in LoadModelData.
        public string SourceExcel { get; set; } = @"..\..\..\..\Test.Repository\model\in_data.xlsx";
        // Not used by the visible code; presumably where the trained model is meant to be saved.
        public string ModelSavePath { get; set; } = @"..\..\..\..\Test.Repository\model\model";
        // Fraction of rows requested for the test set when splitting.
        public double TestFraction { get; set; } = 0.2d;
        private static IDataView _trainingDataView;
        private static MLContext _mlContext;
        private static ITransformer _trainedModel;
        private static IEstimator<ITransformer> pipeline;
        // Declared but never assigned in the visible code.
        private static PredictionEngine<SearchEntry, SearchPrediction> _predEngine;
        private static List<SearchEntry> _trainSet;
        private static List<SearchEntry> _testSet;

        /// <summary>
        /// Creates the MLContext (fixed seed 0), reads every row of the "data" sheet, splits the
        /// rows into train and test sets and materializes both as lists.
        /// </summary>
        public void LoadModelData()
        {
            _mlContext = new MLContext(seed: 0);
            // Heplers.Excel.Query is defined outside this file; it is used here as returning a DataTable.
            var dt = Heplers.Excel.Query(SourceExcel, "SELECT * FROM [data$]");
            var searchEntries = dt.AsEnumerable()
                .Select(r => new SearchEntry { Topic = (string)r["Topic"], Subject = (string)r["Subject"] });
            var dataview = _mlContext.Data.LoadFromEnumerable(searchEntries);
            // NOTE(review): this is the cause of the all-zero test metrics discussed below.
            // Per the ML.NET documentation, rows that share a samplingKeyColumnName value are kept
            // in the same split. Using the label column ("Topic") as the key therefore puts each
            // topic entirely in either train or test, so the model is evaluated only on topics
            // it never saw during training.
            var split = _mlContext.Data
                .TrainTestSplit(dataview, testFraction: TestFraction,
                samplingKeyColumnName: "Topic");
            _trainSet = _mlContext.Data
                .CreateEnumerable<SearchEntry>(split.TrainSet, reuseRowObject: false).ToList();
            _testSet = _mlContext.Data
                .CreateEnumerable<SearchEntry>(split.TestSet, reuseRowObject: false).ToList();
            // Only the training rows are wrapped in a data view here; Evaluate builds the test view itself.
            _trainingDataView = _mlContext.Data.LoadFromEnumerable<SearchEntry>(_trainSet);
        }

        /// <summary>
        /// Defines the feature-engineering estimator chain. Nothing is fitted or transformed here;
        /// the chain is only stored in <c>pipeline</c> for BuildAndTrainModel.
        /// </summary>
        public void ProcessData()
        {
            Console.WriteLine($"=============== Processing Data ===============");
            // Topic -> "Label" (key type), Subject -> "SubjectFeaturized" -> "Features".
            pipeline = _mlContext.Transforms.Conversion.MapValueToKey(inputColumnName: "Topic", outputColumnName: "Label")
                            .Append(_mlContext.Transforms.Text.FeaturizeText(inputColumnName: "Subject", outputColumnName: "SubjectFeaturized"))
                            .Append(_mlContext.Transforms.Concatenate("Features", "SubjectFeaturized"))
                            .AppendCacheCheckpoint(_mlContext);
            Console.WriteLine($"=============== Finished Processing Data ===============");
        }

        /// <summary>
        /// Appends the SDCA non-calibrated multiclass trainer and the key-to-value mapping for
        /// "PredictedLabel" to the stored pipeline, then fits it on the training data view.
        /// </summary>
        public void BuildAndTrainModel()
        {
            var trainingPipeline = pipeline
                    .Append(_mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated("Label", "Features"))
                    .Append(_mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
            Console.WriteLine($"=============== Training the model  ===============");
            _trainedModel = trainingPipeline.Fit(_trainingDataView);
            Console.WriteLine($"=============== Finished Training the model Ending time: {DateTime.Now.ToString()} ===============");
        }
        /// <summary>
        /// Runs the trained model over the test rows and prints micro/macro accuracy,
        /// log-loss and log-loss reduction to the console.
        /// </summary>
        public void Evaluate()
        {
            Console.WriteLine($"=============== Evaluating to get model's accuracy metrics - Starting time: {DateTime.Now.ToString()} ===============");
            var testDataView = _mlContext.Data.LoadFromEnumerable<SearchEntry>(_testSet);
            var testMetrics = _mlContext.MulticlassClassification.Evaluate(_trainedModel.Transform(testDataView));
            Console.WriteLine($"=============== Evaluating to get model's accuracy metrics - Ending time: {DateTime.Now.ToString()} ===============");
            Console.WriteLine($"*************************************************************************************************************");
            Console.WriteLine($"*       Metrics for Multi-class Classification model - Test Data     ");
            Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
            Console.WriteLine($"*       MicroAccuracy:    {testMetrics.MicroAccuracy:0.###}");
            Console.WriteLine($"*       MacroAccuracy:    {testMetrics.MacroAccuracy:0.###}");
            // NOTE(review): the "#.###" custom format renders a value of 0 as an empty string,
            // which is why LogLoss is blank in the sample output; "0.###" would print "0".
            Console.WriteLine($"*       LogLoss:          {testMetrics.LogLoss:#.###}");
            Console.WriteLine($"*       LogLossReduction: {testMetrics.LogLossReduction:#.###}");
            Console.WriteLine($"*************************************************************************************************************");
        }
    }
}

输出结果如下:

*************************************************************************************************************
*       Metrics for Multi-class Classification model - Test Data     
*------------------------------------------------------------------------------------------------------------
*       MicroAccuracy:    0
*       MacroAccuracy:    0
*       LogLoss:          
*       LogLossReduction: NaN
*************************************************************************************************************

我把拆分代码从下面的第一段(带 samplingKeyColumnName)改成了第二段(不带该参数):

// Before: the label column "Topic" is passed as the sampling key.
var split = _mlContext.Data
                .TrainTestSplit(dataview, testFraction: TestFraction, samplingKeyColumnName: "Topic");

// After: no sampling key is passed, so rows are not grouped by "Topic" when splitting.
var split = _mlContext.Data
               .TrainTestSplit(dataview, testFraction: TestFraction);

使用 samplingKeyColumnName: "Topic" 时,我的测试集只包含 2 个不同的主题;去掉该参数后则有 6 个。原因在于:采样键相同的行总是被分到同一个集合中,所以测试集里的主题在训练集中从未出现过,模型不可能预测正确,因此指标很差。

但我对这个结果仍不满意。我总共有 10 个不同的主题,我认为测试集中每个主题都至少应该有几条记录,而 Microsoft.ML 的 TrainTestSplit 似乎无法保证这一点。

于是我写了一个自定义拆分器:

        /// <summary>
        /// Splits <paramref name="searchEntries"/> into a training set and a test set, stratified
        /// by <c>Topic</c>: entries are shuffled, grouped by topic, and roughly
        /// <paramref name="testFraction"/> of each topic (rounded up) goes to the test set.
        /// </summary>
        /// <param name="searchEntries">All labelled entries to split. Must not be null.</param>
        /// <param name="testFraction">Share of each topic to place in the test set, in [0, 1].</param>
        /// <returns>
        /// The training entries (in their original order) and the test entries
        /// (grouped by topic, in shuffled order).
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="searchEntries"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="testFraction"/> is NaN or outside [0, 1].
        /// </exception>
        private (List<SearchEntry> TrainSet, List<SearchEntry> TestSet) TrainTestSplit(List<SearchEntry> searchEntries, double testFraction)
        {
            if (searchEntries == null)
            {
                throw new ArgumentNullException(nameof(searchEntries));
            }

            if (double.IsNaN(testFraction) || testFraction < 0d || testFraction > 1d)
            {
                throw new ArgumentOutOfRangeException(nameof(testFraction), testFraction, "testFraction must be between 0 and 1 inclusive.");
            }

            var rand = new Random();

            var testSet = searchEntries
                .OrderBy(_ => rand.Next())          // shuffle; the sort key is computed once per element
                .GroupBy(entry => entry.Topic)
                .SelectMany(topicGroup =>
                {
                    // The group already holds every entry of this topic, so there is no need
                    // to re-scan the whole input list to count them.
                    var topicCount = topicGroup.Count();
                    var testCount = (int)Math.Ceiling(topicCount * testFraction);

                    // Rounding up must not drain a topic from the training set (e.g. a topic
                    // with a single entry): a topic absent from training can never be predicted.
                    // Only an explicit testFraction of 1 may move everything to the test set.
                    if (testFraction < 1d)
                    {
                        testCount = Math.Min(testCount, topicCount - 1);
                    }

                    return topicGroup.Take(testCount);
                })
                .ToList();

            // SearchEntry does not override Equals, so membership is by reference.
            // Filtering with a set (instead of Except) keeps the original order and does not
            // de-duplicate the training entries.
            var testLookup = new HashSet<SearchEntry>(testSet);
            var trainSet = searchEntries.Where(entry => !testLookup.Contains(entry)).ToList();
            return (trainSet, testSet);
        }