XGBoost feature_names mismatch time series

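I'm trying to predict stock trends, where 1 means the stock went up and 0 means it went down on a given day. My input features are the closing price, the volume and that day's trend, and my output is the next day's trend. When I apply XGBClassifier() I get this error: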
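To make the layout concrete, here is a minimal sketch of how such a training array could be built with NumPy. The prices, volumes and the definition of "trend" (1 if today's close is above yesterday's) are illustrative assumptions, not the asker's data. The error message below reports four feature columns (f0 to f3), so the real layout probably has one more column than this sketch. The point it shows is that the next-day trend goes in the last column, which is what the train[:, :-1] / train[:, -1] split in the code relies on.

```python
import numpy as np

# Hypothetical daily closing prices and volumes (illustrative numbers only).
close = np.array([10.0, 10.5, 10.2, 10.8, 11.0, 10.7])
volume = np.array([1200, 1500, 900, 1700, 1600, 1100], dtype=float)

# Assumed trend definition: 1 if the close rose versus the previous day, else 0.
# trend[i] describes day i + 1, because day 0 has no previous close.
trend = (close[1:] > close[:-1]).astype(float)

# Features for days 1..n-2 (close, volume, that day's trend) ...
features = np.column_stack([close[1:-1], volume[1:-1], trend[:-1]])
# ... labelled with the NEXT day's trend.
target = trend[1:]

# Target in the LAST column, so train[:, :-1] and train[:, -1] split it back out.
train = np.column_stack([features, target])
print(train.shape)
```

The first and last days are dropped: the first has no previous close to define its trend, and the last has no following day to provide a label.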

ValueError                                Traceback (most recent call last)
<ipython-input-101-d14cdb520e55> in <module>
      1 val = np.array(test[0, 0]).reshape(1, -1)
      2 
----> 3 pred = model.predict(val)
      4 print(pred[0])

~/opt/anaconda3/lib/python3.8/site-packages/xgboost/sklearn.py in predict(self, data, output_margin, ntree_limit, validate_features, base_margin)
    968         if ntree_limit is None:
    969             ntree_limit = getattr(self, "best_ntree_limit", 0)
--> 970         class_probs = self.get_booster().predict(
    971             test_dmatrix,
    972             output_margin=output_margin,

~/opt/anaconda3/lib/python3.8/site-packages/xgboost/core.py in predict(self, data, output_margin, ntree_limit, pred_leaf, pred_contribs, approx_contribs, pred_interactions, validate_features, training)
   1483 
   1484         if validate_features:
-> 1485             self._validate_features(data)
   1486 
   1487         length = c_bst_ulong()

~/opt/anaconda3/lib/python3.8/site-packages/xgboost/core.py in _validate_features(self, data)
   2058                             ', '.join(str(s) for s in my_missing))
   2059 
-> 2060                 raise ValueError(msg.format(self.feature_names,
   2061                                             data.feature_names))
   2062 

ValueError: feature_names mismatch: ['f0', 'f1', 'f2', 'f3'] ['f0']
expected f2, f1, f3 in input data

My code is as follows:

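import numpy as np
from xgboost import XGBClassifier

def xgb_predict(train, val):
    train = np.array(train)
    # All columns except the last are features; the last column is the target.
    x, y = train[:, :-1], train[:, -1]
    model = XGBClassifier()
    model.fit(x, y)

    # Reshape the single observation into a 2-D array with one row.
    val = np.array(val).reshape(1, -1)
    pred = model.predict(val)
    return pred[0]

xgb_predict(train, test[0, 0])

I get the error on the line pred = model.predict(val). Many thanks for your help :)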

Edit: Included a sample of data

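You have to select the columns of the test data the same way you do for the training data. test[0, 0] picks a single cell, so reshape(1, -1) turns it into a row with one column (only f0), while the model was trained on four feature columns (f0 to f3). That means your last line should be: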

xgb_predict(train, test[0, :-1])

That way you select all columns/features except the last one, which is the target value.
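The difference between the two calls can be checked with plain NumPy, without training a model. This sketch uses a made-up 2x5 test array (four feature columns plus the target in the last column). check_feature_count is a hypothetical helper, not part of XGBoost. It fails early with a clearer message before the row ever reaches predict.

```python
import numpy as np

# Hypothetical test array laid out like the training data:
# 4 feature columns followed by the target in the last column.
test = np.arange(10, dtype=float).reshape(2, 5)

# What the question passes: test[0, 0] is one scalar, so reshape(1, -1)
# yields a 1x1 row, i.e. a single feature (f0).
wrong = np.array(test[0, 0]).reshape(1, -1)

# The fix: every column of the row except the target.
right = np.array(test[0, :-1]).reshape(1, -1)

def check_feature_count(row, n_features):
    """Hypothetical helper: raise a readable error on a column-count mismatch."""
    if row.shape[1] != n_features:
        raise ValueError(f"expected {n_features} features, got {row.shape[1]}")
    return row

n_train_features = test.shape[1] - 1  # same slicing as train[:, :-1]
print(wrong.shape, right.shape)        # (1, 1) (1, 4)
check_feature_count(right, n_train_features)
```

With the real model, the same check would be applied to the row just before model.predict(val).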