sklearn svm 拟合不佳

sklearn svm provides a bad fit

我试图为我的示例数据绘制 SVM 的拟合结果,但遇到了一个问题:画出来的图看起来完全不对,这很奇怪,因为我用的是这里(here)的示例代码(更具体地说,是“发生了什么?”那一部分)。他们的代码在我这里运行正常,所以我认为问题出在我的数据上。我注意到拟合得到的系数非常小,这显然会导致画出的直线出错。

这是可重现的代码。

import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm
import matplotlib as mpl

plt.figure(figsize=(5, 5))


def in_cir(x, y):
    """Tell whether the point (x, y) falls inside the purple circle (x^2 + y^2 <= 4)."""
    if x**2 + y**2 <= 4:
        return True
    return False


def f(x, e):
    """True function: a straight line in x with the noise term e added on top."""
    return 1.16*x + 0.1 + e


ran = np.arange(-5, 6)
lsp = np.linspace(-5, 5, 170)                        # X1 axis
np.random.seed(69)
noise = [np.random.normal(0, 1.5) for _ in lsp]      # one scalar draw per sample
dots = f(lsp, noise)                                 # X2 axis

# Split the samples by colour: purple inside the circle, blue everywhere else.
blue_dots, pur_dots, lsp1, lsp2 = [], [], [], []
for x1, x2 in zip(lsp, dots):
    if in_cir(x1, x2):
        pur_dots.append(x2)                          # purple dot: X2 ...
        lsp2.append(x1)                              # ... and its X1
    else:
        blue_dots.append(x2)                         # same for the blue ones
        lsp1.append(x1)

plt.scatter(lsp1, blue_dots, color='cornflowerblue')
plt.scatter(lsp2, pur_dots, color='magenta')
plt.xlabel('$X_1$', fontsize=15)
plt.ylabel('$X_2$', fontsize=15)


# x is a 2d array holding the (X1, X2) coordinates of every dot;
# y holds 'p' for each purple dot and 'b' for each blue one.
x = np.array(list(zip(lsp, dots)))
inside = np.array([in_cir(x1, x2) for x1, x2 in zip(lsp, dots)])
y = np.where(inside, 'p', 'b')

ft = svm.SVC(kernel='linear', C=1).fit(x, y)         # Fitting svc

# From here on the code follows the linked example.
w = ft.coef_[0]
print('w', w)                                        # w components are really small
slope = -w[0] / w[1]
xx = np.linspace(-5, 5)
yy = slope * xx - (ft.intercept_[0]) / w[1]          # This is where it all goes wrong

# Lines parallel to the separating line, through the first and last support vector.
first_sv = ft.support_vectors_[0]
yy_down = slope * xx + (first_sv[1] - slope * first_sv[0])
last_sv = ft.support_vectors_[-1]
yy_up = slope * xx + (last_sv[1] - slope * last_sv[0])

for line, fmt in ((yy, 'k-'), (yy_down, 'k--'), (yy_up, 'k--')):
    plt.plot(xx, line, fmt)


# Clamp the axes to keep the plot interpretable: without these limits the
# high values of the fitted lines squish the scatter.
plt.ylim(-5, 5.5)
plt.xlim(-5, 4.5)
plt.show()

输出为:

如你所见,结果很糟糕。如果有人能解释一下我哪里做错了,我将不胜感激。


编辑:我实际上设法做到了这一点。这是我写的代码:

# Second attempt: blue points form two clusters on either side of a purple
# cluster at the origin; a linear SVC is fitted and its regions are plotted.
plt.figure(figsize=(7,7))
np.random.seed(420)

ran = np.arange(-5,6)       # not used anywhere else in this snippet
st = 1                      # standard deviation of the noise added to every coordinate
# b: 50 blue points -- 25 centred on (-3, -2.5) followed by 25 centred on (2.5, 3.5).
# p: 50 purple points centred on the origin.
# The random draws happen in exactly this order, so the order matters for reproducibility.
b, p = np.array([ (-3+np.random.normal(0,st), -2.5+np.random.normal(0,st)) for i in range(25) ]+\
[ (2.5+np.random.normal(0,st), 3.5+np.random.normal(0,st)) for i in range(25) ]), np.array([ (np.random.normal(0,st), np.random.normal(0,st)) for i in range(50) ])

plt.scatter(b[:,0], b[:,1], color='cornflowerblue')
plt.scatter(p[:,0], p[:,1], color='magenta')
plt.xlabel('$X_1$', fontsize=15)
plt.ylabel('$X_2$', fontsize=15)


# Training set: first 25 blue points, then the 50 purple points, then the last
# 25 blue points; y labels them 0 (blue) / 1 (purple) in the same order.
x, y = np.concatenate( (np.concatenate( (b[:25], p) ), b[-25:]) ), [0]*25 + [1]*50 + [0]*25
ft = svm.SVC(kernel='linear').fit(x, y)

# Evaluate the decision function on the four corners of the square [-5, 6] x [-5, 6].
# NOTE(review): the names look swapped (`by` is meshgrid's first output and is fed in
# as the first feature), and the trailing .T compensates for that when the grid is
# passed to plt.contour(bx, by, bo) below. This only lines up because both axes use
# the same [-5, 6] range -- confirm before changing either range.
by, bx = np.meshgrid([-5, 6], [-5, 6])
bo = ft.decision_function(np.vstack([by.ravel(), bx.ravel()]).T).reshape(bx.shape).T

# Predict the class on a dense grid (step 0.01) covering slightly more than the axes limits.
xx, yy = np.meshgrid(np.arange(-5.1, 4.6, 0.01),
                     np.arange(-5.1, 5.6, 0.01))
Z = ft.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# Dotted hatching with no fill colour over the predicted regions.
C = plt.contourf(xx, yy, Z,colors='none', hatches=['.'])
colors=['cornflowerblue', 'magenta']
# First collection gets the blue edge colour, every other one magenta --
# presumably so the hatch dots take the colour of the predicted class.
# NOTE(review): `ContourSet.collections` appears to be deprecated in recent
# Matplotlib releases (3.8+) -- confirm against the installed version.
for j, collection in enumerate(C.collections):
  if j == 0: collection.set_edgecolor(colors[0])
  else: collection.set_edgecolor(colors[1])
# Decision-function contours at -1, 0 and 1, drawn dashed / solid / dashed.
plt.contour(bx, by, bo, colors='0', levels=[-1, 0, 1], linestyles=['--', '-', '--'])


plt.ylim(-5, 5.5)
plt.xlim(-5, 4.5)
plt.show()

结果是:

您正在尝试用线性分类器来区分线性不可分的数据(data which is not linearly separable,即无法用一条直线把两组点分开)。您可以改用其他核函数,例如 RBF:

import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm
from mlxtend.plotting import plot_decision_regions

plt.figure(figsize=(5, 5))


def in_cir(x, y):
    """Return whether (x, y) lies inside or on the purple circle x^2 + y^2 <= 4.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    return x**2 + y**2 <= 4


def f(x, e):
    """True function: the line 1.16*x + 0.1 with the noise term e added."""
    return 1.16 * x + 0.1 + e


lsp = np.linspace(-5, 5, 170)                            # X1 axis
np.random.seed(69)
# One scalar draw per sample, in the same order as the question's code,
# so the generated data set is the same as in the question.
dots = f(lsp, [np.random.normal(0, 1.5) for _ in lsp])   # X2 axis

# Feature matrix: one (X1, X2) row per dot.
x = np.array(list(zip(lsp, dots)))
# plot_decision_regions accepts only integer labels, so encode the classes
# directly as integers: 0 = blue (outside the circle), 1 = purple (inside).
y = in_cir(lsp, dots).astype(int)

# The two classes are not linearly separable, hence the RBF kernel.
ft = svm.SVC(kernel='rbf', C=1).fit(x, y)                # Fitting svc
plot_decision_regions(X=x,
                      y=y,
                      clf=ft,
                      legend=2)
plt.show()