线性回归的梯度下降不起作用
Gradient descent for linear regression is not working
我尝试为一些样本数据使用梯度下降编写线性回归程序。我得到的 theta 值并没有给出最适合数据的值。我已经把数据归一化了。
public class OneVariableRegression {
public static void main(String[] args) {
double x1[] = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
double y[] = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
double theta0 = 0.5;
double theta1 = 0.5;
double temp0;
double temp1;
double alpha = 1.5;
double m = x1.length;
System.out.println(m);
double derivative0 = 0;
double derivative1 = 0;
do {
for (int i = 0; i < x1.length; i++) {
derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];
}
temp0 = theta0 - (alpha * derivative0);
temp1 = theta1 - (alpha * derivative1);
theta0 = temp0;
theta1 = temp1;
//System.out.println("Derivative0 = " + derivative0);
//System.out.println("Derivative1 = " + derivative1);
}
while (derivative0 > 0.0001 || derivative1 > 0.0001);
System.out.println();
System.out.println("theta 0 = " + theta0);
System.out.println("theta 1 = " + theta1);
}
}
是的,您的公式有误。出于某种原因,您在乘法中包含了导数 0 和 1。这严重扭曲了结果。只需删除多余的括号并重试:
derivative0 = derivative0 + (theta0 + (theta1 * x1[i]) - y[i]) * (1/m);
derivative1 = derivative1 + (theta0 + (theta1 * x1[i]) - y[i]) * (1/m) * x1[i];
输出:
20.0
Derivative0 = 0.010499999999999995
Derivative1 = 0.31809711251208517
Derivative0 = 0.0052500000000000185
Derivative1 = 0.1829058398064968
Derivative0 = -0.007874999999999993
Derivative1 = -0.2129262545589219
theta 0 = 0.4881875
theta 1 = 0.06788495336050987
这是否更符合您的预期?
是的,它是凸的。
您使用的导数来自平方误差函数,它是凸函数,因此除了一个全局最小值外不接受任何局部最小值。 (事实上 ,这类问题甚至可以接受称为 正规方程 的封闭形式的解决方案,只是对于大型问题在数值上不易处理,因此使用梯度下降)
而正确答案在 theta0 = 0.4895
和 theta1 = 0.1652
左右,这对于在任何统计计算环境中检查都是微不足道的。 (如果您持怀疑态度,请参阅答案底部)
下面我指出你代码中的错误,修正错误后,你会在小数点后4位得到上面的正确答案。
您的实施问题:
所以你期望它收敛全局最小值是对的,但是你在实现中遇到了问题
每次重新计算 derivative_i
时,您都忘记将其重置为 0(您所做的是在 do{}while()
中累积跨迭代的导数
你在 do while 循环中需要这个
do {
derivative0 = 0;
derivative1 = 0;
...
}
接下来是这个
derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];
x1[i]
因素应单独应用于 (theta0 + (theta1 * x1[i]) - y[i]))
。
你的尝试有点混乱,所以让我们把它写得更清楚一些,这样更接近它的数学方程(1/m)sum(y_hat_i - y_i)x_i
:
// You need fresh vars, don't accumulate the derivatives across gradient descent iterations
derivative0 = 0;
derivative1 = 0;
for (int i = 0; i < m; i++) {
derivative0 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]);
derivative1 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i])*x1[i];
}
这应该让你足够接近,但是,我发现你的学习率 alpha 有点大。当它 太大 时,你的梯度下降将难以 归零 没有你的全局最优,它会在那里徘徊,但不会完全那里。
double alpha = 0.5;
确认结果
运行 并将其与统计软件的答案进行比较
这是您的 .java 文件的 gist on github。
➜ ~ javac OneVariableRegression.java && java OneVariableRegression
20.0
theta 0 = 0.48950064086914064
theta 1 = 0.16520139788757973
我和R比较了
> x
[1] -1.60579308 -1.43676223 -1.26773138 -1.09870053 -0.92966968 -0.76063883
[7] -0.59160798 -0.42257713 -0.25354628 -0.08451543 0.08451543 0.25354628
[13] 0.42257713 0.59160798 0.76063883 0.92966968 1.09870053 1.26773138
[19] 1.43676223 1.60579308
> y
[1] 0.30 0.20 0.24 0.33 0.35 0.28 0.61 0.38 0.38 0.42 0.51 0.60 0.55 0.56 0.53
[16] 0.61 0.65 0.68 0.74 0.87
> lm(y ~ x)
Call:
lm(formula = y ~ x)
Coefficients:
(Intercept) x
0.4895 0.1652
现在你的代码给出了至少 4 位小数的正确答案。
我尝试为一些样本数据使用梯度下降编写线性回归程序。我得到的 theta 值并没有给出最适合数据的值。我已经把数据归一化了。
public class OneVariableRegression {
public static void main(String[] args) {
double x1[] = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
double y[] = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
double theta0 = 0.5;
double theta1 = 0.5;
double temp0;
double temp1;
double alpha = 1.5;
double m = x1.length;
System.out.println(m);
double derivative0 = 0;
double derivative1 = 0;
do {
for (int i = 0; i < x1.length; i++) {
derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];
}
temp0 = theta0 - (alpha * derivative0);
temp1 = theta1 - (alpha * derivative1);
theta0 = temp0;
theta1 = temp1;
//System.out.println("Derivative0 = " + derivative0);
//System.out.println("Derivative1 = " + derivative1);
}
while (derivative0 > 0.0001 || derivative1 > 0.0001);
System.out.println();
System.out.println("theta 0 = " + theta0);
System.out.println("theta 1 = " + theta1);
}
}
是的,您的公式有误。出于某种原因,您在乘法中包含了导数 0 和 1。这严重扭曲了结果。只需删除多余的括号并重试:
derivative0 = derivative0 + (theta0 + (theta1 * x1[i]) - y[i]) * (1/m);
derivative1 = derivative1 + (theta0 + (theta1 * x1[i]) - y[i]) * (1/m) * x1[i];
输出:
20.0
Derivative0 = 0.010499999999999995
Derivative1 = 0.31809711251208517
Derivative0 = 0.0052500000000000185
Derivative1 = 0.1829058398064968
Derivative0 = -0.007874999999999993
Derivative1 = -0.2129262545589219
theta 0 = 0.4881875
theta 1 = 0.06788495336050987
这是否更符合您的预期?
是的,它是凸的。
您使用的导数来自平方误差函数,它是凸函数,因此除了一个全局最小值外不接受任何局部最小值。 (事实上 ,这类问题甚至可以接受称为 正规方程 的封闭形式的解决方案,只是对于大型问题在数值上不易处理,因此使用梯度下降)
而正确答案在 theta0 = 0.4895
和 theta1 = 0.1652
左右,这对于在任何统计计算环境中检查都是微不足道的。 (如果您持怀疑态度,请参阅答案底部)
下面我指出你代码中的错误,修正错误后,你会在小数点后4位得到上面的正确答案。
您的实施问题:
所以你期望它收敛全局最小值是对的,但是你在实现中遇到了问题
每次重新计算 derivative_i
时,您都忘记将其重置为 0(您所做的是在 do{}while()
你在 do while 循环中需要这个
do {
derivative0 = 0;
derivative1 = 0;
...
}
接下来是这个
derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];
x1[i]
因素应单独应用于 (theta0 + (theta1 * x1[i]) - y[i]))
。
你的尝试有点混乱,所以让我们把它写得更清楚一些,这样更接近它的数学方程(1/m)sum(y_hat_i - y_i)x_i
:
// You need fresh vars, don't accumulate the derivatives across gradient descent iterations
derivative0 = 0;
derivative1 = 0;
for (int i = 0; i < m; i++) {
derivative0 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]);
derivative1 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i])*x1[i];
}
这应该让你足够接近,但是,我发现你的学习率 alpha 有点大。当它 太大 时,你的梯度下降将难以 归零 没有你的全局最优,它会在那里徘徊,但不会完全那里。
double alpha = 0.5;
确认结果
运行 并将其与统计软件的答案进行比较
这是您的 .java 文件的 gist on github。
➜ ~ javac OneVariableRegression.java && java OneVariableRegression
20.0
theta 0 = 0.48950064086914064
theta 1 = 0.16520139788757973
我和R比较了
> x
[1] -1.60579308 -1.43676223 -1.26773138 -1.09870053 -0.92966968 -0.76063883
[7] -0.59160798 -0.42257713 -0.25354628 -0.08451543 0.08451543 0.25354628
[13] 0.42257713 0.59160798 0.76063883 0.92966968 1.09870053 1.26773138
[19] 1.43676223 1.60579308
> y
[1] 0.30 0.20 0.24 0.33 0.35 0.28 0.61 0.38 0.38 0.42 0.51 0.60 0.55 0.56 0.53
[16] 0.61 0.65 0.68 0.74 0.87
> lm(y ~ x)
Call:
lm(formula = y ~ x)
Coefficients:
(Intercept) x
0.4895 0.1652
现在你的代码给出了至少 4 位小数的正确答案。