DecisionTreeClassifier 中的 ValueError

ValueError in DecisionTreeClassifier

这是我使用的决策树实现的 link。 https://www.geeksforgeeks.org/decision-tree-implementation-python/

我的数据框仅由“A”和“B”组成,每个都有 512 个值。

data

    1   2   ...      509     510    511    512
A   0.005190    0.00173 ... 0.001730    0.000577    0.002884    0.000577
A   0.000597    0.006567 ... 0.000597   0.000597    0.001194    0.001194
B   0.000582    0.010477 ... 0.001746   0.001164    0.001243    0.003108
A   0.009323    0.001865 ... 0.001865   0.001243    0.003108    0.000622
A   0.000531    0.003186 ... 0.003186   0.001593    0.002124    0.001062

...
X = data.values[:, 1:5]
Y = data.values[:, 0]

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size = 0.3, random_state = 100)

clf_gini = DecisionTreeClassifier(criterion = "gini", random_state = 100,max_depth=3, min_samples_leaf=5)
clf_gini.fit(X_train, y_train)

然而,当我调用fit函数时,它运行在代码的最后一行出现了一个值错误。即使我更改了参数值,它也不起作用。

ValueError                                Traceback (most recent call last)
<ipython-input-19-484db0a3d479> in <module>
      1 # Train with gini
      2 clf_gini = DecisionTreeClassifier(criterion = "gini", random_state = 100,max_depth=3, min_samples_leaf=5)
----> 3 clf_gini.fit(X_train, y_train)

~\anaconda3\envs\myenv\lib\site-packages\sklearn\tree\_classes.py in fit(self, X, y, sample_weight, check_input, X_idx_sorted)
    901         """
    902 
--> 903         super().fit(
    904             X, y,
    905             sample_weight=sample_weight,

~\anaconda3\envs\myenv\lib\site-packages\sklearn\tree\_classes.py in fit(self, X, y, sample_weight, check_input, X_idx_sorted)
    189 
    190         if is_classification:
--> 191             check_classification_targets(y)
    192             y = np.copy(y)
    193 

~\anaconda3\envs\myenv\lib\site-packages\sklearn\utils\multiclass.py in check_classification_targets(y)
    181     if y_type not in ['binary', 'multiclass', 'multiclass-multioutput',
    182                       'multilabel-indicator', 'multilabel-sequences']:
--> 183         raise ValueError("Unknown label type: %r" % y_type)
    184 
    185 

ValueError: Unknown label type: 'continuous'

老实说,我很困惑。有人可以帮我解决这个问题吗?欣赏一下。

您的 y 标签有问题。如果您的模型应该预测样本是否属于 class AB,您应该根据您的数据集使用索引作为标签 y,因为它包含 class ['A', 'B']:

X = data.values
y = data.index.values

data.values 将 return 所有列值,而 data.index.values 将 return 您将索引作为一个 numpy 数组。