Fit polynomial to point cloud in 3D using PolynomialFeatures

I have a set of points representing the motion of a particle in 3D. I am trying to fit a polynomial to these points so that I can represent the particle's trajectory with a single curve. The trajectory is clearly a polynomial of degree 2 or 3.

Following the suggestion here, I wrote the following code to fit a polynomial to my data:

%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

# Point cloud
data = np.array([[41., 57., 92.], [39., 57., 92.4], [43., 57., 91.2], [23., 47., 119.6],
                 [27., 47., 115.2], [25., 45., 122.], [25., 49., 114.], [29., 49., 109.6],
                 [29., 47., 114.4], [27., 49., 111.2], [23., 45., 125.6], [31., 49., 106.8],
                 [25., 47., 117.6], [39., 55., 95.6], [37., 53., 98.4], [35., 55., 96.8],
                 [33., 53., 116.8], [23., 43., 132.8], [25., 41., 145.2], [25., 43., 133.6],
                 [29., 51., 106.4], [31., 53., 121.2], [31., 51., 104.8], [41., 55., 93.6],
                 [33., 51., 103.6], [35., 53., 99.6], [37., 55., 96.4]])

x = data[:,0]
y = data[:,1]
z = data[:,2]

# sort data to avoid plotting problems
x, y, z = zip(*sorted(zip(x, y, z)))

x = np.array(x)
y = np.array(y)
z = np.array(z)

data_xy = np.array([x,y])

poly = PolynomialFeatures(degree=2)
X_t = poly.fit_transform(data_xy.transpose())

clf = LinearRegression()
clf.fit(X_t, z)
z_pred = clf.predict(X_t)
print(clf.coef_)
print(clf.intercept_)

fig = plt.figure()
ax = plt.subplot(projection='3d')
ax.plot(x, y, z_pred, 'r') # fit line
ax.scatter(x, y, z)
fig.set_dpi(150)

The problem is that I get a very strange result:

Any idea what is going on?

Edit: I expect a single curve that fits the data. For example, applying exactly the same approach to a different data set gives me this:

Suggestions for other approaches are also welcome. Thanks!

I was able to solve the problem. The choice of dependent and independent variables was the issue here. If I instead take X as the independent variable and treat Y and Z as dependent on it, I get a much better result:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

def polynomial_regression3d(x, y, z, degree):
    # sort data to avoid plotting problems
    x, y, z = zip(*sorted(zip(x, y, z)))

    x = np.array(x)
    y = np.array(y)
    z = np.array(z)
    
    data_yz = np.array([y,z])
    data_yz = data_yz.transpose()

    polynomial_features = PolynomialFeatures(degree=degree)
    x_poly = polynomial_features.fit_transform(x[:, np.newaxis])

    model = LinearRegression()
    model.fit(x_poly, data_yz)
    y_poly_pred = model.predict(x_poly)

    rmse = np.sqrt(mean_squared_error(data_yz,y_poly_pred))
    r2 = r2_score(data_yz,y_poly_pred)
    print("RMSE:", rmse)
    print("R-squared", r2)
    
    # plot
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    ax.scatter(x, data_yz[:,0], data_yz[:,1])
    ax.plot(x, y_poly_pred[:,0], y_poly_pred[:,1], color='r')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    fig.set_dpi(150)
    plt.show()

However, when I compute R-squared, it reports a value of 0.8, which is not bad, but could perhaps be improved. Maybe weighted least squares would help here.
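Weighted least squares can be sketched with NumPy alone: `np.polyfit` accepts per-point weights through its `w` argument (for Gaussian noise, use `w = 1/sigma`). The weighting scheme below is purely illustrative, since the real per-point uncertainties are not known; uniform weights reproduce the ordinary fit:

```python
import numpy as np

# Point cloud from the question (x, y, z per row)
data = np.array([[41., 57., 92.], [39., 57., 92.4], [43., 57., 91.2], [23., 47., 119.6],
                 [27., 47., 115.2], [25., 45., 122.], [25., 49., 114.], [29., 49., 109.6],
                 [29., 47., 114.4], [27., 49., 111.2], [23., 45., 125.6], [31., 49., 106.8],
                 [25., 47., 117.6], [39., 55., 95.6], [37., 53., 98.4], [35., 55., 96.8],
                 [33., 53., 116.8], [23., 43., 132.8], [25., 41., 145.2], [25., 43., 133.6],
                 [29., 51., 106.4], [31., 53., 121.2], [31., 51., 104.8], [41., 55., 93.6],
                 [33., 51., 103.6], [35., 53., 99.6], [37., 55., 96.4]])
x, y, z = data[:, 0], data[:, 1], data[:, 2]

# Hypothetical weights: start uniform, then down-weight points whose z value
# is far from the median, treating them as noisier. In practice the weights
# should come from known measurement uncertainty (w_i = 1 / sigma_i).
w = np.ones_like(x)
w[np.abs(z - np.median(z)) > 20] = 0.2

# np.polyfit performs a weighted least-squares fit when `w` is given.
# Fit y(x) and z(x) separately, each as a degree-2 polynomial in x.
coef_y = np.polyfit(x, y, deg=2, w=w)
coef_z = np.polyfit(x, z, deg=2, w=w)

# Evaluate the fitted curves at the sample x positions
y_fit = np.polyval(coef_y, x)
z_fit = np.polyval(coef_z, x)
```

The two coefficient arrays then describe the trajectory parametrically as (x, y(x), z(x)), which can be plotted with the same `ax.plot(x, y_fit, z_fit)` call used above.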