DecisionTreeClassifier 中的 ValueError
ValueError in DecisionTreeClassifier
这是我使用的决策树实现的 link。
https://www.geeksforgeeks.org/decision-tree-implementation-python/
我的数据框仅由“A”和“B”组成,每个都有 512 个值。
data
1 2 ... 509 510 511 512
A 0.005190 0.00173 ... 0.001730 0.000577 0.002884 0.000577
A 0.000597 0.006567 ... 0.000597 0.000597 0.001194 0.001194
B 0.000582 0.010477 ... 0.001746 0.001164 0.001243 0.003108
A 0.009323 0.001865 ... 0.001865 0.001243 0.003108 0.000622
A 0.000531 0.003186 ... 0.003186 0.001593 0.002124 0.001062
...
X = data.values[:, 1:5]
Y = data.values[:, 0]
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size = 0.3, random_state = 100)
clf_gini = DecisionTreeClassifier(criterion = "gini", random_state = 100,max_depth=3, min_samples_leaf=5)
clf_gini.fit(X_train, y_train)
然而,当我调用fit
函数时,它运行在代码的最后一行出现了一个值错误。即使我更改了参数值,它也不起作用。
ValueError Traceback (most recent call last)
<ipython-input-19-484db0a3d479> in <module>
1 # Train with gini
2 clf_gini = DecisionTreeClassifier(criterion = "gini", random_state = 100,max_depth=3, min_samples_leaf=5)
----> 3 clf_gini.fit(X_train, y_train)
~\anaconda3\envs\myenv\lib\site-packages\sklearn\tree\_classes.py in fit(self, X, y, sample_weight, check_input, X_idx_sorted)
901 """
902
--> 903 super().fit(
904 X, y,
905 sample_weight=sample_weight,
~\anaconda3\envs\myenv\lib\site-packages\sklearn\tree\_classes.py in fit(self, X, y, sample_weight, check_input, X_idx_sorted)
189
190 if is_classification:
--> 191 check_classification_targets(y)
192 y = np.copy(y)
193
~\anaconda3\envs\myenv\lib\site-packages\sklearn\utils\multiclass.py in check_classification_targets(y)
181 if y_type not in ['binary', 'multiclass', 'multiclass-multioutput',
182 'multilabel-indicator', 'multilabel-sequences']:
--> 183 raise ValueError("Unknown label type: %r" % y_type)
184
185
ValueError: Unknown label type: 'continuous'
老实说,我很困惑。有人可以帮我解决这个问题吗?欣赏一下。
您的 y
标签有问题。如果您的模型应该预测样本是否属于 class A
或 B
,您应该根据您的数据集使用索引作为标签 y,因为它包含 class ['A', 'B']
:
X = data.values
y = data.index.values
data.values
将 return 所有列值,而 data.index.values
将 return 您将索引作为一个 numpy 数组。
这是我使用的决策树实现的 link。 https://www.geeksforgeeks.org/decision-tree-implementation-python/
我的数据框仅由“A”和“B”组成,每个都有 512 个值。
data
1 2 ... 509 510 511 512
A 0.005190 0.00173 ... 0.001730 0.000577 0.002884 0.000577
A 0.000597 0.006567 ... 0.000597 0.000597 0.001194 0.001194
B 0.000582 0.010477 ... 0.001746 0.001164 0.001243 0.003108
A 0.009323 0.001865 ... 0.001865 0.001243 0.003108 0.000622
A 0.000531 0.003186 ... 0.003186 0.001593 0.002124 0.001062
...
X = data.values[:, 1:5]
Y = data.values[:, 0]
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size = 0.3, random_state = 100)
clf_gini = DecisionTreeClassifier(criterion = "gini", random_state = 100,max_depth=3, min_samples_leaf=5)
clf_gini.fit(X_train, y_train)
然而,当我调用fit
函数时,它运行在代码的最后一行出现了一个值错误。即使我更改了参数值,它也不起作用。
ValueError Traceback (most recent call last)
<ipython-input-19-484db0a3d479> in <module>
1 # Train with gini
2 clf_gini = DecisionTreeClassifier(criterion = "gini", random_state = 100,max_depth=3, min_samples_leaf=5)
----> 3 clf_gini.fit(X_train, y_train)
~\anaconda3\envs\myenv\lib\site-packages\sklearn\tree\_classes.py in fit(self, X, y, sample_weight, check_input, X_idx_sorted)
901 """
902
--> 903 super().fit(
904 X, y,
905 sample_weight=sample_weight,
~\anaconda3\envs\myenv\lib\site-packages\sklearn\tree\_classes.py in fit(self, X, y, sample_weight, check_input, X_idx_sorted)
189
190 if is_classification:
--> 191 check_classification_targets(y)
192 y = np.copy(y)
193
~\anaconda3\envs\myenv\lib\site-packages\sklearn\utils\multiclass.py in check_classification_targets(y)
181 if y_type not in ['binary', 'multiclass', 'multiclass-multioutput',
182 'multilabel-indicator', 'multilabel-sequences']:
--> 183 raise ValueError("Unknown label type: %r" % y_type)
184
185
ValueError: Unknown label type: 'continuous'
老实说,我很困惑。有人可以帮我解决这个问题吗?欣赏一下。
您的 y
标签有问题。如果您的模型应该预测样本是否属于 class A
或 B
,您应该根据您的数据集使用索引作为标签 y,因为它包含 class ['A', 'B']
:
X = data.values
y = data.index.values
data.values
将 return 所有列值,而 data.index.values
将 return 您将索引作为一个 numpy 数组。