如何使 Java 中的点积方法更快或更有效?

How to make my dot product method in Java faster or more efficient?

我有一个小的 Java 方法,用于在输入向量和矩阵之间执行点积。这是代码:

    public void calcOutput() {
    outputs = new float[output];
    float sum = 0F;

    for(int j = 0; j < output; j++) {
        for(int i = 0; i < input; i++) {
            sum += inputs[i] * weights[j][i];
        }

        outputs[j] = sum;
    }
}

基本上这应该做的是获取我的输入向量 'inputs' 并使用我命名为 "weights" 的矩阵执行点积。然后将输出置于输出向量 'outputs' 中。

我怎样才能让它更快或更有效率?如果有帮助,我的权重矩阵也不需要是矩阵。我只是需要一种方法来轻松访问相应的索引。

谢谢

不,没有比这更好的了。这是您可以实现的最简单的方法,算法遵循良好的内存缓存方法,即外循环遵循数组的外索引,内循环遍历一个子数组中的元素。

也许对内部数组使用临时变量会有所帮助,但我想 JIT 会处理这个问题。

另外,有一个错误,sum变量应该在外层循环的范围内,而不是方法范围内。它需要在外循环的每次迭代中重置:

for(int j = 0; j < output; j++) {
    // NOTE the line:
    float sum = 0;
    // and the reference to inner array:
    byte[] row = weights[j];
    for(int i = 0; i < input; i++) {
        sum += inputs[i] * row[i];
    }

    outputs[j] = sum;
}

这是我会做的。通过反转外循环和内循环,可以减少 inputs 数组中的查找次数。此外,您不需要 sum 变量 - 您可以直接在 outputs 数组中进行添加。

    float[] outputs = new float[output];

    for(int i = 0; i < input; i++) {
        float inputsI = inputs[i];
        for(int j = 0; j < output; j++) {
            outputs[j] += inputsI * weights[j][i];
        }

    }

我希望这只会快一点点。在几乎所有现实世界的应用程序中,不值得担心像这个这样的微小优化。

有一些方法比编写普通的点积要好得多。朴素的实现将被 C2 向量化,但顺序缩减阶段太慢了,向量化乘法的好处被抵消了。在 Java 现在 (JDK10),您可以做的最好的事情是使用部分和来展开以打破数据依赖性。 C2 将发出标量代码,但它会使用一些流水线,您最多可以获得 4 flops/cycle。

float s0 = 0f;
float s1 = 0f;
float s2 = 0f;
float s3 = 0f;
float s4 = 0f;
float s5 = 0f;
float s6 = 0f;
float s7 = 0f;
for (int i = 0; i < size; i += 8) {
  s0 = Math.fma(left[i + 0],  right[i + 0], s0);
  s1 = Math.fma(left[i + 1],  right[i + 1], s1);
  s2 = Math.fma(left[i + 2],  right[i + 2], s2);
  s3 = Math.fma(left[i + 3],  right[i + 3], s3);
  s4 = Math.fma(left[i + 4],  right[i + 4], s4);
  s5 = Math.fma(left[i + 5],  right[i + 5], s5);
  s6 = Math.fma(left[i + 6],  right[i + 6], s6);
  s7 = Math.fma(left[i + 7],  right[i + 7], s7);
}
return s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;

为了尽可能快,您需要使用累加器进行显式矢量化。像这样的代码可以用 Project Panama Vector API.

编写
var sum1 = YMM_FLOAT.zero();
var sum2 = YMM_FLOAT.zero();
var sum3 = YMM_FLOAT.zero();
var sum4 = YMM_FLOAT.zero();
int width = YMM_FLOAT.length();
for (int i = 0; i < size; i += width * 4) {
  sum1 = YMM_FLOAT.fromArray(left, i).fma(YMM_FLOAT.fromArray(right, i), sum1);
  sum2 = YMM_FLOAT.fromArray(left, i + width).fma(YMM_FLOAT.fromArray(right, i + width), sum2);
  sum3 = YMM_FLOAT.fromArray(left, i + width * 2).fma(YMM_FLOAT.fromArray(right, i + width * 2), sum3);
  sum4 = YMM_FLOAT.fromArray(left, i + width * 3).fma(YMM_FLOAT.fromArray(right, i + width * 3), sum4);
}
return sum1.addAll() + sum2.addAll() + sum3.addAll() + sum4.addAll();

查看此 blog post 以获得基准和深入分析。