Java 中的神经网络未能反向传播

Neural Network In Java Failing to Back Propogate

我已经为神经网络编写了代码,但是当我训练我的网络时,它没有产生所需的输出(网络没有学习,有时在训练时出现 NaN 值)。我的反向传播算法有什么问题?下面附上我是如何分别推导出权重梯度和偏差梯度的公式的。可以找到完整代码 here.

public double[][] predict(double[][] input) {
    if(input.length != this.activations.get(0).length || input[0].length != this.activations.get(0)[0].length) {
        throw new IllegalArgumentException("Prediction Error!");
    }
    this.activations.set(0, input);
    for(int i = 1; i < this.activations.size(); i++) {
        this.activations.set(i, this.sigmoid(this.add(this.multiply(this.weights.get(i-1), this.activations.get(i-1)), this.biases.get(i-1))));
    }
    return this.activations.get(this.n-1);
}

public void train(double[][] input, double[][] target) {
    //calculate activations
    this.predict(input);
    //calculate weight gradients
    for(int l = 0; l < this.weightGradients.size(); l++) {
        for(int i = 0; i < this.weightGradients.get(l).length; i++) {
            for(int j = 0; j < this.weightGradients.get(l)[0].length; j++) {
                this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);
            }
        }
    }
    //calculated bias gradients
    for(int l = 0; l < this.biasGradients.size(); l++) {
        for(int i = 0; i < this.biasGradients.get(l).length; i++) {
            for(int j = 0; j < this.biasGradients.get(l)[0].length; j++) {
                this.biasGradients.get(l)[i][j] = this.gradientOfBias(l, i, j, target);
            }
        }
    }
    //apply gradient
    for(int i = 0; i < this.weights.size(); i++) {
        this.weights.set(i, this.subtract(this.weights.get(i), this.weightGradients.get(i)));
    }
    for(int i = 0; i < this.biases.size(); i++) {
        this.biases.set(i, this.subtract(this.biases.get(i), this.biasGradients.get(i)));
    }
}

private double gradientOfWeight(int l, int i, int j, double[][] t) { //when referring to A, use l+1 because A[0] is input vector, n-1 because n starts at 1
    double z = (this.activations.get(l + 1)[i][0] * (1.0 - this.activations.get(l + 1)[i][0]) * this.activations.get(l)[j][0]);
    if((l + 1) < (this.n - 1)) {
        double sum = 0.0;
        for(int k = 0; k < this.weights.get(l + 1).length; k++) {
            sum += this.gradientOfWeight(l + 1, k, i, t)*this.weights.get(l + 1)[k][i];
        }
        return ((z * sum) / this.activations.get(l + 1)[i][0]);
    } else if((l + 1) == (this.n - 1)) {
        return 2.0 * (this.activations.get(l + 1)[i][0] - t[i][0]) * z;
    }
    throw new IllegalArgumentException("Weight Gradient Calculation Error!");
}

您看到的 NaN 是由于下溢,您需要使用 BigDecimal class 而不是 double 以获得更高的精度。参考这些以更好地理解 bigdecimal class java sample use , BigDecimal API Reference

这个问题涉及的数学量加上缺少 data/reproduction 代码使得几乎不可能回答“我的 NaN 在哪里”的原始问题。

相反,我建议您将这个问题重新考虑为一个更简单的问题,“我怎样才能知道我的代码中像 NaN 这样的值来自哪里”。

如果您可以 运行 您的代码 IDE,它们中的大多数将支持条件断点。即断点将在变量达到某个值时暂停您的代码。在你的情况下,我建议 运行 在你的首选 IDE 中使用条件断点检测值为 NaN 的代码。

您可以在此 SO post 中阅读更多有关如何设置它的信息,其中在此线程中很好地提到了 NaN 双重检查的主题:

另一个后续考虑是考虑你需要把这些断点放在哪里。简短的回答是将它们放在计算 double 的任何地方,因为任何这些计算都可能引入 NaN。

为此,我提出以下两条建议:

首先,在当前计算双精度的位置放置一个断点,以查看 NaN 是否来自这些计算。那就是这两个变量:

double z = ...

double sum = ...

其次,将对 gradientOfWeight 的调用重构为 return 到一个临时变量中,然后在这些中间计算上放置一个类似的断点。

所以不用

this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);

你会:

double interrimComputationToListenForNaNon = this.gradientOfWeight(l, i, j, target);
this.weightGradients.get(l)[i][j] = interrimComputationToListenForNaNon;

拥有这些临时变量更方便,可以让您轻松监控计算,而无需以任何重大方式更改调用。可能有一种不需要中间变量的更聪明的方法来做到这一点,但这个方法似乎最容易监控和解释。