如何在逻辑回归中绘制三个 类 的决策边界?

How to plot descion boundary for three classes in logistic regression?

我想用三个 类 为虹膜数据绘制决策边界。但是,我不知道怎么画。

from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
iris = load_iris()

x_index = 0
y_index = 1

formatter = plt.FuncFormatter(lambda i, *args: iris.target_names[int(i)])

plt.figure(figsize=(5, 4))
plt.scatter(iris.data[:, x_index], iris.data[:, y_index], c=iris.target)
plt.colorbar(ticks=[0, 1, 2], format=formatter)
plt.xlabel(iris.feature_names[x_index])
plt.ylabel(iris.feature_names[y_index])

plt.tight_layout()
plt.show()

参数$$\theta$$为

theta= array([-0.52952307, -1.14831508,  2.69829141])

前提是我没有得到你的 theta 数组的维度(它似乎是二元分类问题的输出,而你正在考虑具有两个特征和三个 类),这是一个如何绘制决策边界、训练通用多项式逻辑回归模型的示例:

初始设置:

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

custom_cmap = ListedColormap(['#b4a7d6','#93c47d','#fff2cc'])

iris = load_iris()

x_index = 0
y_index = 1
formatter = plt.FuncFormatter(lambda i, *args: iris.target_names[int(i)])

您可以实例化和训练逻辑回归模型(根据您的设置将其拟合到特征 'sepal length (cm)''sepal width (cm)' 上)。

lr = LogisticRegression(multi_class='multinomial', random_state=42, max_iter=500)
lr.fit(iris.data[:, [0, 1]], iris.target)

然后,您可以创建一个网格 [x0_min, x0_max]x[x1_min, x1_max] 点,您将在其上进行预测;最终,您可以绘制训练示例以及定义边界的等高线。

x0, x1 = np.meshgrid(
    np.linspace(iris.data[:, x_index].min(), iris.data[:, x_index].max(), 500).reshape(-1, 1),
    np.linspace(iris.data[:, y_index].min(), iris.data[:, y_index].max(), 500).reshape(-1, 1)
)

X_new = np.c_[x0.ravel(), x1.ravel()]
y_pred = lr.predict(X_new)
zz = y_pred.reshape(x0.shape)

plt.figure(figsize=(10, 5))
plt.contourf(x0, x1, zz, cmap=custom_cmap)
plt.scatter(iris.data[:, x_index], iris.data[:, y_index], c=iris.target)
plt.colorbar(ticks=[0, 1, 2], format=formatter)

plt.xlabel(iris.feature_names[x_index])
plt.ylabel(iris.feature_names[y_index])

plt.tight_layout()
plt.show()

这是来自 sklearn User Guide 的示例。