了解多元线性回归的梯度下降 python 实现
Understanding Gradient Descent for Multivariate Linear Regression python implementation
看来下面的代码正确地找到了梯度下降:
def gradientDescent(x, y, theta, alpha, m, numIterations):
xTrans = x.transpose()
for i in range(0, numIterations):
hypothesis = np.dot(x, theta)
loss = hypothesis - y
cost = np.sum(loss ** 2) / (2 * m)
print("Iteration %d | Cost: %f" % (i, cost))
# avg gradient per example
gradient = np.dot(xTrans, loss) / m
# update
theta = theta - alpha * gradient
return theta
现在假设我们有以下示例数据:
对于示例数据的第一行,我们将有:
x = [2104, 5, 1, 45]
、theta = [1,1,1,1]
、y = 460
。
但是,我们没有在行中指定:
hypothesis = np.dot(x, theta)
loss = hypothesis - y
要考虑样本数据的哪一行。那这段代码怎么能正常工作?
首先:恭喜您学习了 Coursera 上的机器学习课程! :)
hypothesis = np.dot(x,theta)
将同时计算所有 x(i) 的假设,将每个 h_theta(x(i)) 保存为 hypothesis
的一行。所以不需要引用一行。
loss = hypothesis - y
也是如此。
这看起来像是 Andrew Ng 出色的机器学习课程中的幻灯片!
代码有效是因为您使用的是矩阵类型(来自 numpy 库?),并且已重载基本运算符(+、-、*、/)来执行矩阵运算 - 因此您不需要遍历每一行。
假设y表示为y = w0 + w1*x1 + w2*x2 + w3*x3 + ...... wn*xn
其中 w0 是截距。 np.dot(x, theta)
假设公式中的截距是如何计算出来的
我假设 X = 表示特征的数据。并且 theta 可以是像 [1,1,1., ] of rowSize(data)
这样的数组
看来下面的代码正确地找到了梯度下降:
def gradientDescent(x, y, theta, alpha, m, numIterations):
xTrans = x.transpose()
for i in range(0, numIterations):
hypothesis = np.dot(x, theta)
loss = hypothesis - y
cost = np.sum(loss ** 2) / (2 * m)
print("Iteration %d | Cost: %f" % (i, cost))
# avg gradient per example
gradient = np.dot(xTrans, loss) / m
# update
theta = theta - alpha * gradient
return theta
现在假设我们有以下示例数据:
对于示例数据的第一行,我们将有:
x = [2104, 5, 1, 45]
、theta = [1,1,1,1]
、y = 460
。
但是,我们没有在行中指定:
hypothesis = np.dot(x, theta)
loss = hypothesis - y
要考虑样本数据的哪一行。那这段代码怎么能正常工作?
首先:恭喜您学习了 Coursera 上的机器学习课程! :)
hypothesis = np.dot(x,theta)
将同时计算所有 x(i) 的假设,将每个 h_theta(x(i)) 保存为 hypothesis
的一行。所以不需要引用一行。
loss = hypothesis - y
也是如此。
这看起来像是 Andrew Ng 出色的机器学习课程中的幻灯片!
代码有效是因为您使用的是矩阵类型(来自 numpy 库?),并且已重载基本运算符(+、-、*、/)来执行矩阵运算 - 因此您不需要遍历每一行。
假设y表示为y = w0 + w1*x1 + w2*x2 + w3*x3 + ...... wn*xn 其中 w0 是截距。 np.dot(x, theta)
假设公式中的截距是如何计算出来的我假设 X = 表示特征的数据。并且 theta 可以是像 [1,1,1., ] of rowSize(data)
这样的数组