ML.Net 的多类分类置信度

Confidence of Multiclass classification with ML.Net

我找到了 ML.NET 的完美介绍:https://www.codeproject.com/Articles/1249611/Machine-Learning-with-ML-Net-and-Csharp-VB-Net。它帮助我解决了 ML.NET.

的一些问题

但其中一个仍然是真实的:

当我向语言检测器(LanguageDetection 示例)发送一些文本时,我总是会收到结果。即使 classification 对非常短的文本片段没有信心。我可以获得有关对 multiclass classification 的信心的信息吗?或者属于某个 class 的概率在使用相邻句子语言的第二轮算法中使用它?

根据@Jon的提示,我修改了CodeProject中的原始示例。这段代码可以通过下面的link找到:https://github.com/sotnyk/LanguageDetector/tree/Code-for-Whosebug-52536943

主要是(按照 Jon 的建议)添加字段:

public float[] Score;

进入class ClassPrediction。

如果此字段存在,我们收到 probabilities/confidences 的 multiclass class 每个 class 的化。

但是我们在原始示例中遇到了另一个困难。它使用浮点值作为类别标签。但它不是分数数组中的索引。要将分数索引映射到类别,我们应该使用方法 TryGetScoreLabelNames:

if (!model.TryGetScoreLabelNames(out var scoreClassNames))
    throw new Exception("Can't get score classes");

但此方法不适用于 class 标签作为浮点值。所以我更改了原始 .tsv 文件和字段 ClassificationData.LanguageClass 和 ClassPrediction.Class 以使用字符串标签作为 class 名称。

未在问题主题中直接提及的其他更改:

  • 更新了 nuget-packages 版本。
  • 我有兴趣使用 lightGBM classifier(它对我来说显示出最好的质量)。但是当前版本的 nuget-package 有一个针对非 NetCore 应用程序的错误。因此,我将示例平台更改为 NetCore20/Standard.
  • 未注释的模型使用 lightGBM classifier。

在名为 Prediction 的应用程序中打印的每种语言的分数。现在,这部分代码如下所示:

internal static async Task<PredictionModel<ClassificationData, ClassPrediction>> PredictAsync(
    string modelPath,
    IEnumerable<ClassificationData> predicts = null,
    PredictionModel<ClassificationData, ClassPrediction> model = null)
{
    if (model == null)
    {
        new LightGbmArguments();
        model = await PredictionModel.ReadAsync<ClassificationData, ClassPrediction>(modelPath);
    }

    if (predicts == null) // do we have input to predict a result?
        return model;

    // Use the model to predict the positive or negative sentiment of the data.
    IEnumerable<ClassPrediction> predictions = model.Predict(predicts);

    Console.WriteLine();
    Console.WriteLine("Classification Predictions");
    Console.WriteLine("--------------------------");

    // Builds pairs of (sentiment, prediction)
    IEnumerable<(ClassificationData sentiment, ClassPrediction prediction)> sentimentsAndPredictions =
        predicts.Zip(predictions, (sentiment, prediction) => (sentiment, prediction));

    if (!model.TryGetScoreLabelNames(out var scoreClassNames))
        throw new Exception("Can't get score classes");

    foreach (var (sentiment, prediction) in sentimentsAndPredictions)
    {
        string textDisplay = sentiment.Text;

        if (textDisplay.Length > 80)
            textDisplay = textDisplay.Substring(0, 75) + "...";

        string predictedClass = prediction.Class;

        Console.WriteLine("Prediction: {0}-{1} | Test: '{2}', Scores:",
            prediction.Class, predictedClass, textDisplay);
        for(var l = 0; l < prediction.Score.Length; ++l)
        {
            Console.Write($"  {l}({scoreClassNames[l]})={prediction.Score[l]}");
        }
        Console.WriteLine();
        Console.WriteLine();
    }
    Console.WriteLine();

    return model;
}

}