如何在 ML.NET 中获取 OneHotEncoding 标签

How to get OneHotEncoding labels in ML.NET

将类别列编码为单热编码向量是一项简单的任务。

但是,我不知道如何从提供的代码中获取单热编码标签,因为了解哪个编码列代表类别标签很重要。

所以下面的代码将类别列编码为one-hot编码向量。

//create dataview from the string array 'colVector'
IDataView data = mlContext.Data.LoadFromEnumerable<IrisFlower>(colVector);

//create a pipeline to transform the category into one-hot encoding vector 
var fitData = mlContext.Transforms.Categorical.OneHotEncoding(nameof(IrisFlower.Label)).Fit(data);
var transData = fitData.Transform(data);
var convertedData = mlContext.Data.CreateEnumerable<EncodedIrisFlower>(transData, true);

所以,我的问题是如何从上面的代码中获取类别标签(sentosavirginica、、versicolor)。

GetColumn 方法应该可以帮助您做到这一点。

类似于你的管道,我有下面的。我确实在 IrisData class.

中添加了标签字段
var data = context.Data.LoadFromTextFile<IrisData>("./iris.data", hasHeader: false, separatorChar: ',');

var shuffledData = context.Data.ShuffleRows(data);

var transData = context.Transforms.Categorical.OneHotEncoding("LabelOneHot", nameof(IrisData.Label))
    .Fit(shuffledData)
    .Transform(shuffledData);

从那里我们可以提取列值。

var oneHotLabels = transData.GetColumn<float[]>("LabelOneHot").ToArray();
var originalLabels = transData.GetColumn<string>("Label").ToArray();

对于 Label 列,只需获取不同项目的数组。

var labels = originalLabels.Distinct().ToArray();

然后可以遍历它们以确定基于单一热编码的正确标签。

foreach (var item in oneHotLabels)
{
    var maxItem = Array.IndexOf(item, item.Max());

    Console.WriteLine(labels[maxItem]);
}

希望对您有所帮助!

您要使用的是一种名为GetSlotNames的方法。这将为您提供一个 VBuffer<ReadOnlyMemory<char>>,其中缓冲区中的每个字符串都是 OneHotEncoding 向量中相应索引的标签。

    MLContext mlContext = new MLContext();

    IrisFlower[] colVector = new IrisFlower[]
    {
        new IrisFlower() { Label = "a" },
        new IrisFlower() { Label = "b" },
        new IrisFlower() { Label = "c" }
    };

    IDataView data = mlContext.Data.LoadFromEnumerable<IrisFlower>(colVector);

    var fitData = mlContext.Transforms.Categorical.OneHotEncoding(nameof(IrisFlower.Label)).Fit(data);
    var transData = fitData.Transform(data);
    var convertedData = mlContext.Data.CreateEnumerable<EncodedIrisFlower>(transData, true);

    VBuffer<ReadOnlyMemory<char>> labels = default;
    transData.Schema["Label"].GetSlotNames(ref labels);
    foreach (var label in labels.DenseValues())
    {
        Console.WriteLine(label);
    }