Julia 中的线性回归和矩阵除法

Linear regression and matrix division in Julia

众所周知的 OLS 公式是 (X'X)^(-1)X'y,其中 XnxKynx1

在 Julia 中实现这一点的一种方法是 (X'*X)\X'*y

但我发现 X\y 给出了几乎相同的输出,只是出现了微小的计算错误。

他们总是计算相同的东西吗(只要n>k)?如果可以,我应该使用哪一个?

简短回答:不,使用第一个(众所周知的)。

长答案:

线性回归模型是Xβ = y,很容易推导β = X \ y,这是你的第二种方法。然而,在大多数情况下(当 X 不可逆时),这是错误的,因为你不能简单地左乘 X^-1。正确的做法是改为求解β = argmin{‖y - Xβ‖^2},这就引出了第一种方法

为了证明它们并不总是相同的,简单构造一个 X 不可逆的情况:

julia> X = rand(10, 10)
10×10 Array{Float64,2}:
 0.938995  0.32773   0.740556  0.300323   0.98479    0.48808    0.748006   0.798089  0.864154  0.869864
 0.973832  0.99791   0.271083  0.841392   0.743448   0.0951434  0.0144092  0.785267  0.690008  0.494994
 0.356408  0.312696  0.543927  0.951817   0.720187   0.434455   0.684884   0.72397   0.855516  0.120853
 0.849494  0.989129  0.165215  0.76009    0.0206378  0.259737   0.967129   0.733793  0.798215  0.252723
 0.364955  0.466796  0.227699  0.662857   0.259522   0.288773   0.691278   0.421251  0.593215  0.542583
 0.126439  0.574307  0.577152  0.664301   0.60941    0.742335   0.459951   0.516649  0.732796  0.990509
 0.430213  0.763126  0.737171  0.433884   0.85549    0.163837   0.997908   0.586575  0.257428  0.33239
 0.28398   0.162054  0.481452  0.903363   0.780502   0.994575   0.131594   0.191499  0.702596  0.0967979
 0.42463   0.142     0.705176  0.0481886  0.728082   0.709598   0.630134   0.139151  0.423227  0.942262
 0.197805  0.526095  0.562136  0.648896   0.805806   0.168869   0.200355   0.557305  0.69514   0.227137

julia> y = rand(10, 1)
10×1 Array{Float64,2}:
 0.7751785556478308
 0.24185992335144801
 0.5681904264574333
 0.9134364924569847
 0.20167825754443536
 0.5776727022413637
 0.05289808385359085
 0.5841180308242171
 0.2862768657856478
 0.45152080383822746

julia> ((X' * X) ^ -1) * X' * y
10×1 Array{Float64,2}:
 -0.3768345891121706
  0.5900885565174501
 -0.6326640292669291
 -1.3922334538787071
  0.06182039005215956
  1.0342060710792016
  0.045791973670925995
  0.7237081408801955
  1.4256831037950832
 -0.6750765481219443

julia> X \ y
10×1 Array{Float64,2}:
 -0.37683458911228906
  0.5900885565176254
 -0.6326640292676649
 -1.3922334538790346
  0.061820390052523294
  1.0342060710793235
  0.0457919736711274
  0.7237081408802206
  1.4256831037952566
 -0.6750765481220102

julia> X[2, :] = X[1, :]
10-element Array{Float64,1}:
 0.9389947787349187
 0.3277301697101178
 0.7405555185711721
 0.30032257202572477
 0.9847899425069042
 0.48807977638742295
 0.7480061513093117
 0.79808859136911
 0.8641540973071822
 0.8698636291189576

julia> ((X' * X) ^ -1) * X' * y
10×1 Array{Float64,2}:
  0.7456524759867015
  0.06233042922132548
  2.5600126098899256
  0.3182206475232786
 -2.003080524452619
  0.272673133766017
 -0.8550165639656011
  0.40827327221785403
  0.2994419115664999
 -0.37876151249955264

julia> X \ y
10×1 Array{Float64,2}:
  3.852193379477664e15
 -2.097948470376586e15
  9.077766998701864e15
  5.112094484728637e15
 -5.798433818338726e15
 -2.0446050874148052e15
 -3.300267174800096e15
  2.990882423309131e14
 -4.214829360472345e15
  1.60123572911982e15

根据文档,X\y 的结果是(使用符号 \(A, B) 而不是 Xy):

For rectangular A the result is the minimum-norm least squares solution

我猜这是你的情况,因为你假设 n>k(所以你的矩阵不是正方形)。所以你可以放心使用X\y。实际上,使用它比使用标准公式更好,因为即使 X 的等级小于 min(n,k),您也会得到结果,而标准公式 (X'*X)^(-1)*X'*y 将失败或产生数值不稳定的结果如果 X'*X 几乎是单数。

如果 X 是正方形(这不是你的情况)那么我们在文档中有一些不同的规则:

For input matrices A and B, the result X is such that A*X == B when A is square

这意味着如果您的矩阵是奇异的,则 \ 算法会产生错误,或者如果矩阵几乎是奇异的,则会产生数值不稳定的结果(在实践中,通常 lu 内部调用的函数对于一般的密集矩阵可能会抛出 SingularException).

如果你想要一个包罗万象的解决方案(对于方阵和非方阵),那么可以使用 qr(X, Val(true)) \ y

X 是平方时,有唯一解,LU-factorization (with pivoting) 是一种数值稳定的计算方法。这就是反斜杠在这种情况下使用的算法。

X不是平方时(大多数回归问题都是这种情况),则没有唯一解但有唯一的最小二乘解。 QR factorization 求解 Xβ = y 的方法是一种生成最小二乘解的数值稳定方法,在这种情况下 X\y 使用 QR 分解,从而给出 OLS 解。

注意数字稳定这个词。虽然 (X'*X)\X'*y 理论上总是会给出与反斜杠相同的结果,但实际上反斜杠(使用正确的分解选择)会更精确。这是因为分解算法实现为 numerically stable。由于在执行 (X'*X)\X'*y 时会累积浮点错误,因此不建议您将此表格用于任何实际的数值计算。

相反,(X'*X)\X'*y 在某种程度上等同于 SVD 分解,它是最稳定的算法,但也是最昂贵的(事实上,它基本上是写出 Moore-Penrose pseudoinverse which is how an SVD factorization is used to solve a linear system). To directly do an SVD factorization using a pivoted SVD, do svdfact(X) \ y on v0.6 or svd(X) \ y on v0.7. Doing this directly is more stable than (X'*X)\X'*y. Note that qrfact(X) \ y or qr(X) \ y (v0.7) is for QR. See the factorizations portion of the documentation 以获得更多细节所有的选择。