用两条线拟合曲线的次二次算法

Sub-quadratic algorithm for fitting a curve with two lines

问题是找到实值二维曲线(由点集给出)与由两条线组成的折线的最佳拟合。

一种蛮力方法是为曲线的每个点找到 "left" 和 "right" 线性拟合,并选择误差最小的一对。我可以在遍历曲线点的同时逐步计算两个线性拟合,但我找不到逐步计算误差的方法。因此,这种方法产生二次复杂度。

问题是是否有一种算法可以提供次二次复杂度?

第二个问题是是否有适用于此类算法的方便的 C++ 库?


编辑 用单线拟合,有公式:

m = (Σxiyi - ΣxiΣyi/N) / (Σxi2 - (Σxi)2/N)
b = Σyi/N - m * Σxi/N

其中 m 是斜率,b 是线的偏移量。 有这样一个拟合误差的公式,就是解决问题的最好方法。

免责声明:我不想弄清楚如何在 C++ 中执行此操作,因此我将使用 Python (numpy) 表示法。这些概念是完全可以转移的,因此您可以毫无困难地翻译回您选择的语言。

假设您有一对包含数据点的数组,xy,并且 x 是单调递增的。还假设您总是 select 一个分区点,每个分区至少留下两个元素,因此方程是可解的。

现在您可以计算一些相关量:

N = len(x)

sum_x_left = x[0]
sum_x2_left = x[0] * x[0]
sum_y_left = y[0]
sum_y2_left = y[0] * y[0]
sum_xy_left = x[0] * y[0]

sum_x_right = x[1:].sum()
sum_x2_right = (x[1:] * x[1:]).sum()
sum_y_right = y[1:].sum()
sum_y2_right = (y[1:] * y[1:]).sum()
sum_xy_right = (x[1:] * y[1:]).sum()

我们需要这些数量(O(N) 进行初始化)的原因是您可以直接使用它们来计算线性回归参数的一些众所周知的公式。例如,y = m * x + b 的最佳 mb

给出
μx = Σxi/N
μy = Σyi/N
m = Σ(xi - μx)(yi - μy) / Σ(xi - μx)2
b = μy - m * μx

误差平方和由

给出
e = Σ(yi - m * xi - b)2

这些可以使用简单的代数展开为以下内容:

m = (Σxiyi - ΣxiΣyi/N) / (Σxi2 - (Σxi)2/N)
b = Σyi/N - m * Σxi/N
e = Σyi2 + m2 * Σxi2 + N * b2 - 2 * m * Σxiyi - 2 * b * Σyi + 2 * m * b * Σxi

因此您可以遍历所有可能性并记录最小值 e:

for p in range(1, N - 3):
    # shift sums: O(1)
    sum_x_left += x[p]
    sum_x2_left += x[p] * x[p]
    sum_y_left += y[p]
    sum_y2_left += y[p] * y[p]
    sum_xy_left += x[p] * y[p]

    sum_x_right -= x[p]
    sum_x2_right -= x[p] * x[p]
    sum_y_right -= y[p]
    sum_y2_right -= y[p] * y[p]
    sum_xy_right -= x[p] * y[p]

    # compute err: O(1)
    n_left = p + 1
    slope_left = (sum_xy_left - sum_x_left * sum_y_left * n_left) / (sum_x2_left - sum_x_left * sum_x_left / n_left)
    intercept_left = sum_y_left / n_left - slope_left * sum_x_left / n_left
    err_left = sum_y2_left + slope_left * slope_left * sum_x2_left + n_left * intercept_left * intercept_left - 2 * (slope_left * sum_xy_left + intercept_left * sum_y_left - slope_left * intercept_left * sum_x_left)

    n_right = N - n_left
    slope_right = (sum_xy_right - sum_x_right * sum_y_right * n_right) / (sum_x2_right - sum_x_right * sum_x_right / n_right)
    intercept_right = sum_y_right / n_right - slope_right * sum_x_right / n_right
    err_right = sum_y2_right + slope_right * slope_right * sum_x2_right + n_right * intercept_right * intercept_right - 2 * (slope_right * sum_xy_right + intercept_right * sum_y_right - slope_right * intercept_right * sum_x_right)

    err = err_left + err_right
    if p == 1 || err < err_min
        err_min = err
        n_min_left = n_left
        n_min_right = n_right
        slope_min_left = slope_left
        slope_min_right = slope_right
        intercept_min_left = intercept_left
        intercept_min_right = intercept_right

您可能还可以进行其他简化,但这足以实现 O(n) 算法。

为了以防万一,这里有一些我用于此类事情的 C 代码。它对疯狂物理学家所说的没有任何补充。

首先,一个公式。如果您通过一些点拟合直线 y^ : x->a*x+b,则错误由以下公式给出:

E = Sum{ sqr(y[i]-y^(x[i])) }/ N = Vy - Cxy*Cxy/Vx
where 
Vx is the variance of the xs
Vy that of the ys 
Cxy the covariance of the as and the ys

下面的代码使用了一个包含均值、方差、协方差和计数的结构。

函数 moms_acc_pt() 在您添加新点时更新这些。函数moms_line() returns a 和b 为行,报错如上。 return 上的 fmax(0,) 是在接近完美拟合的情况下,舍入误差可能会使(数学上非负的)结果为负。

虽然可以有一个从 momentsT 中删除点的函数,但我发现通过复制、在副本中累积点、获取线来决定向哪个 momentT 添加点更容易然后保留点最适合的一侧的副本,另一侧的原件

typedef struct
{   int n;      // number points
    double  xbar,ybar;  // means of x,y
    double  Vx, Vy;     // variances of x,y
    double  Cxy;        // covariance of x,y
}   momentsT;

// update the moments to include the point x,y
void    moms_acc_pt( momentsT* M, double x, double y)
{   M->n += 1;
double  f = 1.0/M->n;
double  dx = x-M->xbar;
double  dy = y-M->ybar;
    M->xbar += f*dx;
    M->ybar += f*dy;
double  g = 1.0 - f;
    M->Vx   = g*(M->Vx  + f*dx*dx);
    M->Cxy  = g*(M->Cxy + f*dx*dy);
    M->Vy   = g*(M->Vy  + f*dy*dy);
}

// return the moments for the combination of A and B (assumed disjoint)
momentsT    moms_combine( const momentsT* A, const momentsT* B)
{
momentsT    C;
    C.n = A->n + B->n;
double  alpha = (double)A->n/(double)C.n;
double  beta = (double)B->n/(double)C.n;
    C.xbar = alpha*A->xbar + beta*B->xbar;
    C.ybar = alpha*A->ybar + beta*B->ybar;
double  dx = A->xbar - B->xbar;
double  dy = A->ybar - B->ybar;
    C.Vx = alpha*A->Vx + beta*B->Vx + alpha*beta*dx*dx;
    C.Cxy= alpha*A->Cxy+ beta*B->Cxy+ alpha*beta*dx*dy;
    C.Vy = alpha*A->Vy + beta*B->Vy + alpha*beta*dy*dy;
    return C;
}

// line is y^ : x -> a*x + b; return Sum{ sqr( y[i] - y^(x[i])) }/N
double  moms_line( momentsT* M, double* a, double *b)
{   *a = M->Cxy/M->Vx;
    *b = M->ybar - *a*M->xbar;
    return fmax( 0.0, M->Vy - *a*M->Cxy);
}