Julia 中最小二乘坐标下降算法不收敛

Coordinate Descent Algorithm in Julia for Least Squares not converging

作为编写我自己的弹性网络求解器的热身,我正在尝试获得使用坐标下降实现的普通最小二乘法的足够快的版本。

我相信我已经正确地实现了坐标下降算法,但是当我使用 "fast" 版本(见下文)时,算法非常不稳定,输出当特征数量与样本数量相比大小适中时,回归系数通常会溢出 64 位浮点数。

线性回归和 OLS

如果 b = A*x,其中 A 是矩阵,x 是未知回归系数的向量,y 是输出,我想找到最小化的 x

||b - 斧||^2

如果A[j]是A的第j列,A[-j]是没有j列的A,A的列被归一化使得||A[j]||^2 = 1 for所有j,然后坐标更新

坐标下降:

x[j]  <--  A[j]^T * (b - A[-j] * x[-j])

我跟着 these notes (page 9-10) 但推导是简单的微积分。

有人指出,与其一直重新计算 A[j]^T(b - A[-j] * x[-j]),一种更快的方法是使用

快速坐标下降:

x[j]  <--  A[j]^T*r + x[j]

其中总残差 r = b - Ax 是在坐标循环之外计算的。这些更新规则的等价性源于 Ax = A[j]*x[j] + A[-j]*x[-j] 和重新排列项。

我的问题是,虽然第二种方法确实更快,但只要特征数量与样本数量相比不小,它对我来说在数值上就非常不稳定。我想知道是否有人可能对为什么会这样有所了解。我应该注意到,第一种方法更稳定,但随着特征数量接近样本数量,它仍然开始与更多标准方法不一致。

茱莉亚代码

下面是两个更新规则的一些 Julia 代码:

function OLS_builtin(A,b)
    x = A\b
    return(x)
end

function OLS_coord_descent(A,b)    
    N,P = size(A)
    x = zeros(P)
    for cycle in 1:1000
        for j = 1:P 
            x[j] = dot(A[:,j], b - A[:,1:P .!= j]*x[1:P .!= j])
        end    
    end
    return(x)
end

function OLS_coord_descent_fast(A,b) 
    N,P = size(A)
    x = zeros(P)
    for cycle in 1:1000
        r = b - A*x
        for j = 1:P
            x[j] += dot(A[:,j],r)
        end    
    end
    return(x)
end

问题示例

我使用以下内容生成数据:

n = 100
p = 50
σ = 0.1
β_nz = float([i*(-1)^i for i in 1:10])

β = append!(β_nz,zeros(Float64,p-length(β_nz)))
X = randn(n,p); X .-= mean(X,1); X ./= sqrt(sum(abs2(X),1))
y = X*β + σ*randn(n); y .-= mean(y);

这里我使用 p=50,我在 OLS_coord_descent(X,y)OLS_builtin(X,y) 之间得到了很好的一致性,而 OLS_coord_descent_fast(X,y)returns 回归系数的指数大值。

当p小于20时,OLS_coord_descent_fast(X,y)与其他两个一致

猜想

由于 p << n 的情况是一致的,我认为该算法在形式上是正确的,但在数值上不稳定。有没有人想过这个猜测是否正确,如果正确的话,如何在保留算法快速版本的(大部分)性能增益的同时纠正不稳定性?

快速回答:您忘记在每次 x[j] 更新后更新 r。以下是固定函数,其行为类似于 OLS_coord_descent:

function OLS_coord_descent_fast(A,b) 
    N,P = size(A)
    x = zeros(P)
    for cycle in 1:1000
        r = b - A*x
        for j = 1:P
            x[j] += dot(A[:,j],r)
            r -= A[:,j]*dot(A[:,j],r)   # Add this line
        end    
    end
    return(x)
end