MLP Classifier: "ValueError: Unknown label type"

MLP Classifier: "ValueError: Unknown label type"

我正在尝试使用 MLP 分类器创建基本神经网络。 当我使用方法 mlp.fit a 时出现以下错误:

ValueError: Unknown label type: (array([

下面是我的简单代码

df_X_train = df_train[["Pe/Pe_nom","Gas_cons","PthLoad"]]
df_Y_train = df_train["Eff_Th"]

df_X_test = df_test[["Pe/Pe_nom","Gas_cons","PthLoad"]]
df_Y_test = df_test["Eff_Th"]

X_train = np.asarray(df_X_train, dtype="float64")
Y_train = np.asarray(df_Y_train, dtype="float64")
X_test = np.asarray(df_X_test, dtype="float64")
Y_test = np.asarray(df_Y_test, dtype="float64")

from sklearn.neural_network import MLPClassifier

mlp = MLPClassifier(hidden_layer_sizes=(100,), verbose=True)
mlp.fit(X_train, Y_train)

其实我不明白为什么方法fit不喜欢X_trainY_train的float类型。

只是为了让矩阵维度以下的一切都清楚:

X_train.shape --> (720, 3)
Y_train.shape --> (720,)

我希望我问的方式正确,谢谢。

下方完整错误:

> --------------------------------------------------------------------------- ValueError                                Traceback (most recent call
> last) <ipython-input-6-2efb224ab852> in <module>()
>       2 
>       3 mlp = MLPClassifier(hidden_layer_sizes=(100,), verbose=True)
> ----> 4 mlp.fit(X_train, Y_train)
>       5 
>       6 #y_pred_train = mlp.predict(X_train)
> 
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\neural_network\multilayer_perceptron.py
> in fit(self, X, y)
>     971         """
>     972         return self._fit(X, y, incremental=(self.warm_start and
> --> 973                                             hasattr(self, "classes_")))
>     974 
>     975     @property
> 
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\neural_network\multilayer_perceptron.py
> in _fit(self, X, y, incremental)
>     329                              hidden_layer_sizes)
>     330 
> --> 331         X, y = self._validate_input(X, y, incremental)
>     332         n_samples, n_features = X.shape
>     333 
> 
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\neural_network\multilayer_perceptron.py
> in _validate_input(self, X, y, incremental)
>     914         if not incremental:
>     915             self._label_binarizer = LabelBinarizer()
> --> 916             self._label_binarizer.fit(y)
>     917             self.classes_ = self._label_binarizer.classes_
>     918         elif self.warm_start:
> 
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\preprocessing\label.py
> in fit(self, y)
>     282 
>     283         self.sparse_input_ = sp.issparse(y)
> --> 284         self.classes_ = unique_labels(y)
>     285         return self
>     286 
> 
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\utils\multiclass.py
> in unique_labels(*ys)
>      94     _unique_labels = _FN_UNIQUE_LABELS.get(label_type, None)
>      95     if not _unique_labels:
> ---> 96         raise ValueError("Unknown label type: %s" % repr(ys))
>      97 
>      98     ys_labels = set(chain.from_iterable(_unique_labels(y) for y in ys))
> 
> ValueError: Unknown label type: (array([1.        , 0.89534884, 0.58139535, 0.37209302, 0.24418605,
   0.15116279, 0.09302326, 0.23255814, 0.34883721, 0.37209302,
   0.30232558, 0.23255814, 0.18604651, 0.12790698, 0.08139535,
   0.08139535, 0.19767442, 0.27906977, 0.26744186, 0.22093023,
   0.1744186 , 0.11627907, 0.06976744, 0.05813953, 0.1744186 ,
   0.26744186, 0.34883721, 0.40697674, 0.46511628, 0.45348837,
   0.38372093, 0.31395349, 0.26744186, 0.36046512, 0.44186047,
   0.48837209, 0.53488372, 0.48837209, 0.40697674, 0.31395349,
   0.24418605, 0.1744186 , 0.19767442, 0.29069767, 0.36046512,
   0.3255814 , 0.26744186, 0.20930233, 0.13953488, 0.09302326,
   0.04651163, 0.09302326, 0.19767442, 0.29069767, 0.26744186,
   0.20930233, 0.1627907 , 0.11627907, 0.06976744, 0.03488372,
   0.12790698, 0.24418605, 0.31395349, 0.26744186, 0.20930233,
   0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.13953488,
   0.25581395, 0.30232558, 0.24418605, 0.19767442, 0.15116279,
   0.09302326, 0.05813953, 0.04651163, 0.1627907 , 0.26744186,
   0.30232558, 0.24418605, 0.19767442, 0.13953488, 0.09302326,
   0.05813953, 0.06976744, 0.18604651, 0.27906977, 0.27906977,
   0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.03488372,
   0.10465116, 0.22093023, 0.29069767, 0.26744186, 0.22093023,
   0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.12790698,
   0.24418605, 0.30232558, 0.25581395, 0.20930233, 0.15116279,
   0.10465116, 0.05813953, 0.03488372, 0.15116279, 0.26744186,
   0.30232558, 0.25581395, 0.19767442, 0.15116279, 0.09302326,
   0.05813953, 0.09302326, 0.20930233, 0.29069767, 0.26744186,
   0.22093023, 0.1627907 , 0.11627907, 0.06976744, 0.02325581,
   0.12790698, 0.23255814, 0.31395349, 0.26744186, 0.20930233,
   0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.13953488,
   0.25581395, 0.31395349, 0.25581395, 0.20930233, 0.15116279,
   0.10465116, 0.05813953, 0.02325581, 0.11627907, 0.22093023,
   0.29069767, 0.24418605, 0.19767442, 0.13953488, 0.09302326,
   0.04651163, 0.02325581, 0.10465116, 0.20930233, 0.30232558,
   0.25581395, 0.20930233, 0.15116279, 0.10465116, 0.05813953,
   0.03488372, 0.13953488, 0.24418605, 0.31395349, 0.25581395,
   0.20930233, 0.15116279, 0.10465116, 0.15116279, 0.26744186,
   0.3372093 , 0.36046512, 0.30232558, 0.24418605, 0.19767442,
   0.1744186 , 0.25581395, 0.3255814 , 0.38372093, 0.41860465,
   0.34883721, 0.29069767, 0.23255814, 0.1627907 , 0.1744186 ,
   0.27906977, 0.34883721, 0.3255814 , 0.26744186, 0.20930233,
   0.15116279, 0.09302326, 0.04651163, 0.10465116, 0.22093023,
   0.30232558, 0.25581395, 0.20930233, 0.15116279, 0.10465116,
   0.05813953, 0.02325581, 0.12790698, 0.24418605, 0.30232558,
   0.25581395, 0.20930233, 0.15116279, 0.10465116, 0.1627907 ,
   0.26744186, 0.37209302, 0.45348837, 0.51162791, 0.55813953,
   0.59302326, 0.62790698, 0.56976744, 0.48837209, 0.40697674,
   0.36046512, 0.43023256, 0.47674419, 0.48837209, 0.39534884,
   0.30232558, 0.23255814, 0.1627907 , 0.10465116, 0.19767442,
   0.29069767, 0.31395349, 0.25581395, 0.20930233, 0.15116279,
   0.10465116, 0.05813953, 0.02325581, 0.03488372, 0.15116279,
   0.25581395, 0.25581395, 0.20930233, 0.15116279, 0.10465116,
   0.06976744, 0.03488372, 0.04651163, 0.1627907 , 0.26744186,
   0.25581395, 0.20930233, 0.1627907 , 0.11627907, 0.06976744,
   0.03488372, 0.        , 0.10465116, 0.20930233, 0.27906977,
   0.22093023, 0.1744186 , 0.12790698, 0.08139535, 0.08139535,
   0.19767442, 0.29069767, 0.36046512, 0.43023256, 0.48837209,
   0.53488372, 0.56976744, 0.60465116, 0.52325581, 0.45348837,
   0.38372093, 0.45348837, 0.51162791, 0.54651163, 0.54651163,
   0.44186047, 0.36046512, 0.27906977, 0.20930233, 0.1744186 ,
   0.25581395, 0.3372093 , 0.3372093 , 0.27906977, 0.22093023,
   0.1627907 , 0.10465116, 0.05813953, 0.06976744, 0.18604651,
   0.27906977, 0.27906977, 0.22093023, 0.1744186 , 0.12790698,
   0.08139535, 0.03488372, 0.10465116, 0.22093023, 0.30232558,
   0.27906977, 0.22093023, 0.1744186 , 0.11627907, 0.19767442,
   0.29069767, 0.36046512, 0.40697674, 0.34883721, 0.29069767,
   0.23255814, 0.1744186 , 0.20930233, 0.30232558, 0.36046512,
   0.34883721, 0.29069767, 0.23255814, 0.1744186 , 0.11627907,
   0.06976744, 0.11627907, 0.22093023, 0.30232558, 0.27906977,
   0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.12790698,
   0.24418605, 0.3255814 , 0.27906977, 0.23255814, 0.1744186 ,
   0.12790698, 0.08139535, 0.03488372, 0.        , 0.11627907,
   0.22093023, 0.27906977, 0.22093023, 0.1744186 , 0.12790698,
   0.08139535, 0.04651163, 0.02325581, 0.11627907, 0.23255814,
   0.30232558, 0.25581395, 0.19767442, 0.15116279, 0.10465116,
   0.05813953, 0.08139535, 0.19767442, 0.29069767, 0.29069767,
   0.23255814, 0.18604651, 0.13953488, 0.08139535, 0.04651163,
   0.06976744, 0.18604651, 0.27906977, 0.27906977, 0.23255814,
   0.1744186 , 0.12790698, 0.08139535, 0.04651163, 0.12790698,
   0.24418605, 0.3255814 , 0.27906977, 0.22093023, 0.1744186 ,
   0.11627907, 0.06976744, 0.03488372, 0.13953488, 0.24418605,
   0.30232558, 0.25581395, 0.19767442, 0.15116279, 0.10465116,
   0.05813953, 0.02325581, 0.13953488, 0.24418605, 0.26744186,
   0.22093023, 0.1744186 , 0.12790698, 0.06976744, 0.03488372,
   0.08139535, 0.19767442, 0.27906977, 0.29069767, 0.24418605,
   0.19767442, 0.13953488, 0.09302326, 0.11627907, 0.23255814,
   0.3255814 , 0.30232558, 0.25581395, 0.19767442, 0.15116279,
   0.09302326, 0.04651163, 0.08139535, 0.19767442, 0.27906977,
   0.31395349, 0.25581395, 0.19767442, 0.15116279, 0.10465116,
   0.05813953, 0.09302326, 0.20930233, 0.30232558, 0.27906977,
   0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.03488372,
   0.03488372, 0.15116279, 0.25581395, 0.26744186, 0.20930233,
   0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.01162791,
   0.12790698, 0.23255814, 0.31395349, 0.29069767, 0.24418605,
   0.18604651, 0.13953488, 0.09302326, 0.05813953, 0.1744186 ,
   0.27906977, 0.34883721, 0.29069767, 0.23255814, 0.1744186 ,
   0.11627907, 0.06976744, 0.09302326, 0.19767442, 0.30232558,
   0.31395349, 0.26744186, 0.20930233, 0.15116279, 0.10465116,
   0.05813953, 0.09302326, 0.20930233, 0.30232558, 0.27906977,
   0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.03488372,
   0.08139535, 0.20930233, 0.29069767, 0.26744186, 0.20930233,
   0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.09302326,
   0.20930233, 0.27906977, 0.23255814, 0.18604651, 0.13953488,
   0.09302326, 0.04651163, 0.05813953, 0.18604651, 0.26744186,
   0.3372093 , 0.30232558, 0.24418605, 0.19767442, 0.13953488,
   0.09302326, 0.1744186 , 0.27906977, 0.34883721, 0.30232558,
   0.24418605, 0.18604651, 0.13953488, 0.08139535, 0.03488372,
   0.04651163, 0.1627907 , 0.26744186, 0.26744186, 0.22093023,
   0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.03488372,
   0.15116279, 0.25581395, 0.27906977, 0.22093023, 0.1744186 ,
   0.12790698, 0.08139535, 0.03488372, 0.01162791, 0.12790698,
   0.23255814, 0.29069767, 0.24418605, 0.19767442, 0.13953488,
   0.09302326, 0.05813953, 0.05813953, 0.1744186 , 0.27906977,
   0.29069767, 0.24418605, 0.18604651, 0.13953488, 0.09302326,
   0.11627907, 0.23255814, 0.30232558, 0.34883721, 0.29069767,
   0.24418605, 0.18604651, 0.12790698, 0.15116279, 0.25581395,
   0.3255814 , 0.30232558, 0.24418605, 0.19767442, 0.13953488,
   0.09302326, 0.12790698, 0.22093023, 0.30232558, 0.25581395,
   0.20930233, 0.1627907 , 0.11627907, 0.05813953, 0.02325581,
   0.05813953, 0.1744186 , 0.26744186, 0.22093023, 0.1744186 ,
   0.12790698, 0.08139535, 0.04651163, 0.01162791, 0.11627907,
   0.22093023, 0.25581395, 0.22093023, 0.1744186 , 0.12790698,
   0.08139535, 0.03488372, 0.08139535, 0.19767442, 0.27906977,
   0.34883721, 0.29069767, 0.24418605, 0.18604651, 0.13953488,
   0.10465116, 0.22093023, 0.30232558, 0.3255814 , 0.27906977,
   0.22093023, 0.1627907 , 0.10465116, 0.05813953, 0.02325581,
   0.12790698, 0.24418605, 0.29069767, 0.24418605, 0.19767442,
   0.13953488, 0.09302326, 0.05813953, 0.02325581, 0.10465116,
   0.22093023, 0.30232558, 0.24418605, 0.19767442, 0.15116279,
   0.09302326, 0.05813953, 0.02325581, 0.06976744, 0.18604651,
   0.27906977, 0.25581395, 0.20930233, 0.1627907 , 0.10465116,
   0.06976744, 0.03488372, 0.04651163, 0.1627907 , 0.25581395,
   0.3255814 , 0.38372093, 0.44186047, 0.41860465, 0.34883721,
   0.29069767, 0.24418605, 0.25581395, 0.34883721, 0.41860465,
   0.46511628, 0.5       , 0.51162791, 0.41860465, 0.3372093 ,
   0.26744186, 0.20930233, 0.20930233, 0.30232558, 0.37209302,
   0.36046512, 0.29069767, 0.22093023, 0.15116279, 0.10465116,
   0.09302326, 0.19767442, 0.27906977, 0.25581395, 0.20930233,
   0.1627907 , 0.11627907, 0.06976744, 0.02325581, 0.08139535,
   0.19767442, 0.26744186, 0.22093023, 0.1744186 , 0.13953488,
   0.09302326, 0.04651163, 0.02325581, 0.13953488, 0.24418605,
   0.26744186, 0.22093023, 0.1744186 , 0.12790698, 0.08139535,
   0.1744186 , 0.26744186, 0.34883721, 0.40697674, 0.46511628,
   0.41860465, 0.34883721, 0.27906977, 0.22093023, 0.18604651,
   0.27906977, 0.34883721, 0.37209302, 0.30232558, 0.24418605,
   0.1744186 , 0.11627907, 0.06976744, 0.03488372, 0.15116279]),)

看起来你需要 MLPRegressor 而不是 MLPClassifier,当你说你的目标变量需要精度为浮点数时。