Java 中的神经网络未能反向传播
Neural Network In Java Failing to Back Propogate
我已经为神经网络编写了代码,但是当我训练我的网络时,它没有产生所需的输出(网络没有学习,有时在训练时出现 NaN 值)。我的反向传播算法有什么问题?下面附上我是如何分别推导出权重梯度和偏差梯度的公式的。可以找到完整代码 here.
public double[][] predict(double[][] input) {
if(input.length != this.activations.get(0).length || input[0].length != this.activations.get(0)[0].length) {
throw new IllegalArgumentException("Prediction Error!");
}
this.activations.set(0, input);
for(int i = 1; i < this.activations.size(); i++) {
this.activations.set(i, this.sigmoid(this.add(this.multiply(this.weights.get(i-1), this.activations.get(i-1)), this.biases.get(i-1))));
}
return this.activations.get(this.n-1);
}
public void train(double[][] input, double[][] target) {
//calculate activations
this.predict(input);
//calculate weight gradients
for(int l = 0; l < this.weightGradients.size(); l++) {
for(int i = 0; i < this.weightGradients.get(l).length; i++) {
for(int j = 0; j < this.weightGradients.get(l)[0].length; j++) {
this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);
}
}
}
//calculated bias gradients
for(int l = 0; l < this.biasGradients.size(); l++) {
for(int i = 0; i < this.biasGradients.get(l).length; i++) {
for(int j = 0; j < this.biasGradients.get(l)[0].length; j++) {
this.biasGradients.get(l)[i][j] = this.gradientOfBias(l, i, j, target);
}
}
}
//apply gradient
for(int i = 0; i < this.weights.size(); i++) {
this.weights.set(i, this.subtract(this.weights.get(i), this.weightGradients.get(i)));
}
for(int i = 0; i < this.biases.size(); i++) {
this.biases.set(i, this.subtract(this.biases.get(i), this.biasGradients.get(i)));
}
}
private double gradientOfWeight(int l, int i, int j, double[][] t) { //when referring to A, use l+1 because A[0] is input vector, n-1 because n starts at 1
double z = (this.activations.get(l + 1)[i][0] * (1.0 - this.activations.get(l + 1)[i][0]) * this.activations.get(l)[j][0]);
if((l + 1) < (this.n - 1)) {
double sum = 0.0;
for(int k = 0; k < this.weights.get(l + 1).length; k++) {
sum += this.gradientOfWeight(l + 1, k, i, t)*this.weights.get(l + 1)[k][i];
}
return ((z * sum) / this.activations.get(l + 1)[i][0]);
} else if((l + 1) == (this.n - 1)) {
return 2.0 * (this.activations.get(l + 1)[i][0] - t[i][0]) * z;
}
throw new IllegalArgumentException("Weight Gradient Calculation Error!");
}
您看到的 NaN 是由于下溢,您需要使用 BigDecimal class 而不是 double 以获得更高的精度。参考这些以更好地理解 bigdecimal class java sample use , BigDecimal API Reference
这个问题涉及的数学量加上缺少 data/reproduction 代码使得几乎不可能回答“我的 NaN 在哪里”的原始问题。
相反,我建议您将这个问题重新考虑为一个更简单的问题,“我怎样才能知道我的代码中像 NaN 这样的值来自哪里”。
如果您可以 运行 您的代码 IDE,它们中的大多数将支持条件断点。即断点将在变量达到某个值时暂停您的代码。在你的情况下,我建议 运行 在你的首选 IDE 中使用条件断点检测值为 NaN 的代码。
您可以在此 SO post 中阅读更多有关如何设置它的信息,其中在此线程中很好地提到了 NaN 双重检查的主题:
另一个后续考虑是考虑你需要把这些断点放在哪里。简短的回答是将它们放在计算 double 的任何地方,因为任何这些计算都可能引入 NaN。
为此,我提出以下两条建议:
首先,在当前计算双精度的位置放置一个断点,以查看 NaN 是否来自这些计算。那就是这两个变量:
double z = ...
double sum = ...
其次,将对 gradientOfWeight 的调用重构为 return 到一个临时变量中,然后在这些中间计算上放置一个类似的断点。
所以不用
this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);
你会:
double interrimComputationToListenForNaNon = this.gradientOfWeight(l, i, j, target);
this.weightGradients.get(l)[i][j] = interrimComputationToListenForNaNon;
拥有这些临时变量更方便,可以让您轻松监控计算,而无需以任何重大方式更改调用。可能有一种不需要中间变量的更聪明的方法来做到这一点,但这个方法似乎最容易监控和解释。
我已经为神经网络编写了代码,但是当我训练我的网络时,它没有产生所需的输出(网络没有学习,有时在训练时出现 NaN 值)。我的反向传播算法有什么问题?下面附上我是如何分别推导出权重梯度和偏差梯度的公式的。可以找到完整代码 here.
public double[][] predict(double[][] input) {
if(input.length != this.activations.get(0).length || input[0].length != this.activations.get(0)[0].length) {
throw new IllegalArgumentException("Prediction Error!");
}
this.activations.set(0, input);
for(int i = 1; i < this.activations.size(); i++) {
this.activations.set(i, this.sigmoid(this.add(this.multiply(this.weights.get(i-1), this.activations.get(i-1)), this.biases.get(i-1))));
}
return this.activations.get(this.n-1);
}
public void train(double[][] input, double[][] target) {
//calculate activations
this.predict(input);
//calculate weight gradients
for(int l = 0; l < this.weightGradients.size(); l++) {
for(int i = 0; i < this.weightGradients.get(l).length; i++) {
for(int j = 0; j < this.weightGradients.get(l)[0].length; j++) {
this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);
}
}
}
//calculated bias gradients
for(int l = 0; l < this.biasGradients.size(); l++) {
for(int i = 0; i < this.biasGradients.get(l).length; i++) {
for(int j = 0; j < this.biasGradients.get(l)[0].length; j++) {
this.biasGradients.get(l)[i][j] = this.gradientOfBias(l, i, j, target);
}
}
}
//apply gradient
for(int i = 0; i < this.weights.size(); i++) {
this.weights.set(i, this.subtract(this.weights.get(i), this.weightGradients.get(i)));
}
for(int i = 0; i < this.biases.size(); i++) {
this.biases.set(i, this.subtract(this.biases.get(i), this.biasGradients.get(i)));
}
}
private double gradientOfWeight(int l, int i, int j, double[][] t) { //when referring to A, use l+1 because A[0] is input vector, n-1 because n starts at 1
double z = (this.activations.get(l + 1)[i][0] * (1.0 - this.activations.get(l + 1)[i][0]) * this.activations.get(l)[j][0]);
if((l + 1) < (this.n - 1)) {
double sum = 0.0;
for(int k = 0; k < this.weights.get(l + 1).length; k++) {
sum += this.gradientOfWeight(l + 1, k, i, t)*this.weights.get(l + 1)[k][i];
}
return ((z * sum) / this.activations.get(l + 1)[i][0]);
} else if((l + 1) == (this.n - 1)) {
return 2.0 * (this.activations.get(l + 1)[i][0] - t[i][0]) * z;
}
throw new IllegalArgumentException("Weight Gradient Calculation Error!");
}
您看到的 NaN 是由于下溢,您需要使用 BigDecimal class 而不是 double 以获得更高的精度。参考这些以更好地理解 bigdecimal class java sample use , BigDecimal API Reference
这个问题涉及的数学量加上缺少 data/reproduction 代码使得几乎不可能回答“我的 NaN 在哪里”的原始问题。
相反,我建议您将这个问题重新考虑为一个更简单的问题,“我怎样才能知道我的代码中像 NaN 这样的值来自哪里”。
如果您可以 运行 您的代码 IDE,它们中的大多数将支持条件断点。即断点将在变量达到某个值时暂停您的代码。在你的情况下,我建议 运行 在你的首选 IDE 中使用条件断点检测值为 NaN 的代码。
您可以在此 SO post 中阅读更多有关如何设置它的信息,其中在此线程中很好地提到了 NaN 双重检查的主题:
另一个后续考虑是考虑你需要把这些断点放在哪里。简短的回答是将它们放在计算 double 的任何地方,因为任何这些计算都可能引入 NaN。
为此,我提出以下两条建议:
首先,在当前计算双精度的位置放置一个断点,以查看 NaN 是否来自这些计算。那就是这两个变量:
double z = ...
double sum = ...
其次,将对 gradientOfWeight 的调用重构为 return 到一个临时变量中,然后在这些中间计算上放置一个类似的断点。
所以不用
this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);
你会:
double interrimComputationToListenForNaNon = this.gradientOfWeight(l, i, j, target);
this.weightGradients.get(l)[i][j] = interrimComputationToListenForNaNon;
拥有这些临时变量更方便,可以让您轻松监控计算,而无需以任何重大方式更改调用。可能有一种不需要中间变量的更聪明的方法来做到这一点,但这个方法似乎最容易监控和解释。