ML.NET:如何让输入模型(InputModel)支持泛型?

ML.NET: how can I make the input model generic?

我有 3 个多类分类用例,它们的 InputModel 都不同,因为它们具有不同的列和数据结构。我如何重构下面的方法,以便它可以预测任何类型的 InputModel,而无需为了满足 3 种不同的输入数据结构而复制和重复该方法 3 次?

    /// <summary>
    /// Reads every row of a comma-separated text file as <see cref="InputModel"/> and
    /// scores each row with the prediction engine obtained for <paramref name="modelName"/>.
    /// </summary>
    /// <param name="modelName">Name handed to the prediction engine pool to pick the engine.</param>
    /// <param name="testDataPath">Path of the text file (with a header row) containing the rows to score.</param>
    /// <returns>One prediction per loaded row, in the order the rows were enumerated.</returns>
    private List<MulticlassClassificationPrediction> Predict(string modelName, string testDataPath)
    {
        PredictionEngine<InputModel, MulticlassClassificationPrediction> engine =
            _predEnginePool.GetPredictionEngine(modelName: modelName);

        IDataView loadedData = _mlContext.Data.LoadFromTextFile<InputModel>(
            path: testDataPath,
            hasHeader: true,
            separatorChar: ',',
            allowQuoting: true,
            allowSparse: false);

        // Materialize all rows first, so the whole file is read before any prediction runs.
        List<InputModel> inputs = _mlContext.Data
            .CreateEnumerable<InputModel>(loadedData, false)
            .ToList();

        // Every row is scored, not just the first one.
        return inputs.ConvertAll(input => engine.Predict(input));
    }

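如果我没有理解错你的问题,可以试试把该方法改成泛型方法:用类型参数 TInputModel 代替具体的 InputModel,调用时再指定具体类型,例如 Predict<InputModel>(modelName, testDataPath)。注意:_predEnginePool 返回的预测引擎类型也必须与 TInputModel 相匹配。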

/// <summary>
/// Reads every row of a comma-separated text file as <typeparamref name="TInputModel"/> and
/// scores each row with the prediction engine obtained for <paramref name="modelName"/>.
/// </summary>
/// <typeparam name="TInputModel">
/// Row type used both to load the file and as the prediction engine's input type;
/// must be a class with a public parameterless constructor.
/// </typeparam>
/// <param name="modelName">Name handed to the prediction engine pool to pick the engine.</param>
/// <param name="testDataPath">Path of the text file (with a header row) containing the rows to score.</param>
/// <returns>One prediction per loaded row, in the order the rows were enumerated.</returns>
private List<MulticlassClassificationPrediction> Predict<TInputModel>(string modelName, string testDataPath) where TInputModel: class, new()
{
    // NOTE(review): this assignment only type-checks if _predEnginePool yields an engine
    // whose input type is TInputModel; the pool's declaration is not visible here - confirm.
    PredictionEngine<TInputModel, MulticlassClassificationPrediction> engine =
        _predEnginePool.GetPredictionEngine(modelName: modelName);

    IDataView loadedData = _mlContext.Data.LoadFromTextFile<TInputModel>(
        path: testDataPath,
        hasHeader: true,
        separatorChar: ',',
        allowQuoting: true,
        allowSparse: false);

    // Materialize all rows first, so the whole file is read before any prediction runs.
    List<TInputModel> inputs = _mlContext.Data
        .CreateEnumerable<TInputModel>(loadedData, false)
        .ToList();

    // Every row is scored, not just the first one.
    return inputs
        .Select(input => engine.Predict(input))
        .ToList();
}