Xgboost: How to convert prediction probabilities to multiclass labels original names?

I'm using an xgboost multiclass classifier as in the example below. For each row in the X_test dataframe, the model outputs a list whose elements are the probabilities corresponding to each class 'a', 'b', 'c', or 'd', e.g. `[0.44767836 0.2043365 0.15775423 0.19023092]`.

How can I tell which element of the list corresponds to which class (a, b, c, or d)? My goal is to create 4 extra columns a, b, c, d on the dataframe, with the matching probability as the row value in each column.

import numpy as np
import pandas as pd
import xgboost as xgb
import random
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

#Create Example Data
np.random.seed(312)
data = np.random.random((10000, 3))
y = [random.choice('abcd') for _ in range(data.shape[0])]

features = ["x1", "x2", "x3"]
df = pd.DataFrame(data=data, columns=features)
df['y']=y

#Encode target variable
labelencoder = preprocessing.LabelEncoder()
df['y_target'] = labelencoder.fit_transform(df['y'])
    
#Train Test Split    
X_train, X_test, y_train, y_test = train_test_split(df[features], df['y_target'], test_size=0.2, random_state=42, stratify=y)

#Train Model
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

param = {'objective': 'multi:softprob',
         'random_state': 20,
         'tree_method': 'gpu_hist',  # requires a GPU build; use 'hist' on CPU
         'num_class': 4
         }

xgb_model = xgb.train(param, dtrain, 100)

predictions=xgb_model.predict(dtest)

print(predictions)

The prediction columns follow the same order as the encoded labels 0, 1, 2, 3. To recover the original target names, use the `classes_` attribute of the fitted LabelEncoder:
import pandas as pd

pd.DataFrame(predictions, columns=labelencoder.classes_)
>>>     
           a           b           c           d
0       0.133130    0.214460    0.569207    0.083203
1       0.232991    0.275813    0.237639    0.253557
2       0.163103    0.248531    0.114013    0.474352
3       0.296990    0.202413    0.157542    0.343054
4       0.199861    0.460732    0.228247    0.111159
... 
1995    0.021859    0.460219    0.235214    0.282708
1996    0.145394    0.182243    0.225992    0.446370
1997    0.128586    0.318980    0.237229    0.315205
1998    0.250899    0.257968    0.274477    0.216657
1999    0.252377    0.236990    0.221835    0.288798
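To go one step further and also get the predicted class name per row, you can take the argmax over the probability columns and decode it with `inverse_transform`. A minimal, self-contained sketch (the small `predictions` array here is hypothetical, standing in for the model output above):

```python
import numpy as np
import pandas as pd
from sklearn import preprocessing

# Fit the encoder on the same four classes as in the question
labelencoder = preprocessing.LabelEncoder()
labelencoder.fit(['a', 'b', 'c', 'd'])

# Hypothetical model output: one probability per class, per row
predictions = np.array([
    [0.13, 0.21, 0.57, 0.09],
    [0.23, 0.28, 0.24, 0.25],
])

# Probability columns named after the original class names
proba_df = pd.DataFrame(predictions, columns=labelencoder.classes_)

# Predicted label per row: argmax over columns, decoded back to names
proba_df['predicted'] = labelencoder.inverse_transform(
    predictions.argmax(axis=1)
)

print(proba_df)
```

To attach these columns to your test set, align the indices first, e.g. `X_test.reset_index(drop=True).join(proba_df)`.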