实施施特拉森算法
Implementing Strassen's Algorithm
我想在单核上非常快速地乘以矩阵。我浏览了整个网络,发现了一些算法,发现 Strassen 的算法是唯一一个真正由人们实现的算法。我查看了几个示例并得出了以下解决方案。我做了一个简单的基准测试,它生成两个随机填充的 500x500 矩阵。 Strassen 的算法用了 18 秒,而高中算法用了 0.4 秒。其他人在实施算法后非常有前途,那么我的问题是什么,我怎样才能让它更快?
// return C = A * B
private Matrix strassenTimes(Matrix B, int LEAFSIZE) {
Matrix A = this;
if (B.M != A.M || B.N != A.N) throw new RuntimeException("Illegal matrix dimensions.");
if (N <= LEAFSIZE || M <= LEAFSIZE) {
return A.times(B);
}
// make new sub-matrices
int newAcols = (A.N + 1) / 2;
int newArows = (A.M + 1) / 2;
Matrix a11 = new Matrix(newArows, newAcols);
Matrix a12 = new Matrix(newArows, newAcols);
Matrix a21 = new Matrix(newArows, newAcols);
Matrix a22 = new Matrix(newArows, newAcols);
int newBcols = (B.N + 1) / 2;
int newBrows = (B.M + 1) / 2;
Matrix b11 = new Matrix(newBrows, newBcols);
Matrix b12 = new Matrix(newBrows, newBcols);
Matrix b21 = new Matrix(newBrows, newBcols);
Matrix b22 = new Matrix(newBrows, newBcols);
for (int i = 1; i <= newArows; i++) {
for (int j = 1; j <= newAcols; j++) {
a11.setElement(i, j, A.saveGet(i, j)); // top left
a12.setElement(i, j, A.saveGet(i, j + newAcols)); // top right
a21.setElement(i, j, A.saveGet(i + newArows, j)); // bottom left
a22.setElement(i, j, A.saveGet(i + newArows, j + newAcols)); // bottom right
}
}
for (int i = 1; i <= newBrows; i++) {
for (int j = 1; j <= newBcols; j++) {
b11.setElement(i, j, B.saveGet(i, j)); // top left
b12.setElement(i, j, B.saveGet(i, j + newBcols)); // top right
b21.setElement(i, j, B.saveGet(i + newBrows, j)); // bottom left
b22.setElement(i, j, B.saveGet(i + newBrows, j + newBcols)); // bottom right
}
}
Matrix aResult;
Matrix bResult;
aResult = a11.add(a22);
bResult = b11.add(b22);
Matrix p1 = aResult.strassenTimes(bResult, LEAFSIZE);
aResult = a21.add(a22);
Matrix p2 = aResult.strassenTimes(b11, LEAFSIZE);
bResult = b12.minus(b22); // b12 - b22
Matrix p3 = a11.strassenTimes(bResult, LEAFSIZE);
bResult = b21.minus(b11); // b21 - b11
Matrix p4 = a22.strassenTimes(bResult, LEAFSIZE);
aResult = a11.add(a12); // a11 + a12
Matrix p5 = aResult.strassenTimes(b22, LEAFSIZE);
aResult = a21.minus(a11); // a21 - a11
bResult = b11.add(b12); // b11 + b12
Matrix p6 = aResult.strassenTimes(bResult, LEAFSIZE);
aResult = a12.minus(a22); // a12 - a22
bResult = b21.add(b22); // b21 + b22
Matrix p7 = aResult.strassenTimes(bResult, LEAFSIZE);
Matrix c12 = p3.add(p5); // c12 = p3 + p5
Matrix c21 = p2.add(p4); // c21 = p2 + p4
aResult = p1.add(p4); // p1 + p4
bResult = aResult.add(p7); // p1 + p4 + p7
Matrix c11 = bResult.minus(p5);
aResult = p1.add(p3); // p1 + p3
bResult = aResult.add(p6); // p1 + p3 + p6
Matrix c22 = bResult.minus(p2);
// Grouping the results obtained in a single matrix:
int rows = c11.nrRows();
int cols = c11.nrColumns();
Matrix C = new Matrix(A.M, B.N);
for (int i = 1; i <= A.M; i++) {
for (int j = 1; j <= B.N; j++) {
int el;
if (i <= rows) {
if (j <= cols) {
el = c11.get(i, j);
} else {
el = c12.get(i, j - cols);
}
} else {
if (j <= cols) {
el = c21.get(i - rows, j);
} else {
el = c22.get(i - rows, j - rows);
}
}
C.setElement(i, j, el);
}
}
return C;
}
小benchmark有如下代码:
int AM, AN, BM, BN;
AM = 500;
AN = BM = 500;
BN = 500;
Matrix a = new Matrix(AM, AN);
Matrix b = new Matrix(BM, BN);
Random random = new Random();
for (int i = 1; i <= AM; i++) {
for (int j = 1; j <= AN; j++) {
a.setElement(i, j, random.nextInt(20));
}
}
for (int i = 1; i <= BM; i++) {
for (int j = 1; j <= BN; j++) {
b.setElement(i, j, random.nextInt(20));
}
}
System.out.println("strassen: A x B");
long tijd = System.currentTimeMillis();
Matrix c = a.strassenTimes(b);
System.out.println("time = " + (System.currentTimeMillis() - tijd));
System.out.println("normal: A x B");
tijd = System.currentTimeMillis();
Matrix d = a.times(b);
System.out.println("time = " + (System.currentTimeMillis() - tijd));
System.out.println("nr of different elements = " + c.compare(d));
结果如下:
strassen: A x B
time = 18372
normal: A x B
time = 308
nr of different elements = 0
我知道这是一个低代码,但如果你们能帮助我,我会很高兴 ;)
编辑 1:
为了完整起见,我添加了上面代码使用的一些方法。
public int get(int r, int c) {
if (c > nrColumns() || r > nrRows() || c <= 0 || r <= 0) {
throw new ArrayIndexOutOfBoundsException("matrix is of size (" +
nrRows() + ", " + nrColumns() + "), but tries to set element(" + r + ", " + c + ")");
}
return content[r - 1][c - 1];
}
private int saveGet(int r, int c) {
if (c > nrColumns() || r > nrRows() || c <= 0 || r <= 0) {
return 0;
}
return content[r - 1][c - 1];
}
public void setElement(int r, int c, int n) {
if (c > nrColumns() || r > nrRows() || c <= 0 || r <= 0) {
throw new ArrayIndexOutOfBoundsException("matrix is of size (" +
nrRows() + ", " + nrColumns() + "), but tries to set element(" + r + ", " + c + ")");
}
content[r - 1][c - 1] = n;
}
// return C = A + B
public Matrix add(Matrix B) {
Matrix A = this;
if (B.M != A.M || B.N != A.N) throw new RuntimeException("Illegal matrix dimensions.");
Matrix C = new Matrix(M, N);
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
C.content[i][j] = A.content[i][j] + B.content[i][j];
}
}
return C;
}
我应该为 Strassen 的算法选择另一个叶大小。因此我做了一个小实验。似乎叶大小 256 最适合问题中包含的代码。在具有不同叶子大小的地块下方,每次都有一个大小为 1025 x 1025 的随机矩阵。
我将 Strassen 的叶大小为 256 的算法与用于矩阵乘法的简单算法进行了比较,看看它是否真的有所改进。事实证明这是一种改进,请参见下面不同大小的随机矩阵的结果(步长为 10,每个大小重复 50 次)。
下面是矩阵乘法的简单算法的代码:
// return C = A * B
public Matrix times(Matrix B) {
Matrix A = this;
if (A.N != B.M) throw new RuntimeException("Illegal matrix dimensions.");
Matrix C = new Matrix(A.M, B.N);
for (int i = 0; i < C.M; i++) {
for (int j = 0; j < C.N; j++) {
for (int k = 0; k < A.N; k++) {
C.content[i][j] += (A.content[i][k] * B.content[k][j]);
}
}
}
return C;
}
它仍然认为可以在实现上做其他改进,但事实证明叶子大小是一个非常重要的因素。所有实验都是在 Ubuntu 14.04 上使用机器 运行 完成的,规格如下:
CPU: Intel(R) Core(TM) i7-2600K CPU @ 3.40GHz
Memory: 2 x 4GB DDR3 1333 MHz
我想在单核上非常快速地乘以矩阵。我浏览了整个网络,发现了一些算法,发现 Strassen 的算法是唯一一个真正由人们实现的算法。我查看了几个示例并得出了以下解决方案。我做了一个简单的基准测试,它生成两个随机填充的 500x500 矩阵。 Strassen 的算法用了 18 秒,而高中算法用了 0.4 秒。其他人在实施算法后非常有前途,那么我的问题是什么,我怎样才能让它更快?
// return C = A * B
private Matrix strassenTimes(Matrix B, int LEAFSIZE) {
Matrix A = this;
if (B.M != A.M || B.N != A.N) throw new RuntimeException("Illegal matrix dimensions.");
if (N <= LEAFSIZE || M <= LEAFSIZE) {
return A.times(B);
}
// make new sub-matrices
int newAcols = (A.N + 1) / 2;
int newArows = (A.M + 1) / 2;
Matrix a11 = new Matrix(newArows, newAcols);
Matrix a12 = new Matrix(newArows, newAcols);
Matrix a21 = new Matrix(newArows, newAcols);
Matrix a22 = new Matrix(newArows, newAcols);
int newBcols = (B.N + 1) / 2;
int newBrows = (B.M + 1) / 2;
Matrix b11 = new Matrix(newBrows, newBcols);
Matrix b12 = new Matrix(newBrows, newBcols);
Matrix b21 = new Matrix(newBrows, newBcols);
Matrix b22 = new Matrix(newBrows, newBcols);
for (int i = 1; i <= newArows; i++) {
for (int j = 1; j <= newAcols; j++) {
a11.setElement(i, j, A.saveGet(i, j)); // top left
a12.setElement(i, j, A.saveGet(i, j + newAcols)); // top right
a21.setElement(i, j, A.saveGet(i + newArows, j)); // bottom left
a22.setElement(i, j, A.saveGet(i + newArows, j + newAcols)); // bottom right
}
}
for (int i = 1; i <= newBrows; i++) {
for (int j = 1; j <= newBcols; j++) {
b11.setElement(i, j, B.saveGet(i, j)); // top left
b12.setElement(i, j, B.saveGet(i, j + newBcols)); // top right
b21.setElement(i, j, B.saveGet(i + newBrows, j)); // bottom left
b22.setElement(i, j, B.saveGet(i + newBrows, j + newBcols)); // bottom right
}
}
Matrix aResult;
Matrix bResult;
aResult = a11.add(a22);
bResult = b11.add(b22);
Matrix p1 = aResult.strassenTimes(bResult, LEAFSIZE);
aResult = a21.add(a22);
Matrix p2 = aResult.strassenTimes(b11, LEAFSIZE);
bResult = b12.minus(b22); // b12 - b22
Matrix p3 = a11.strassenTimes(bResult, LEAFSIZE);
bResult = b21.minus(b11); // b21 - b11
Matrix p4 = a22.strassenTimes(bResult, LEAFSIZE);
aResult = a11.add(a12); // a11 + a12
Matrix p5 = aResult.strassenTimes(b22, LEAFSIZE);
aResult = a21.minus(a11); // a21 - a11
bResult = b11.add(b12); // b11 + b12
Matrix p6 = aResult.strassenTimes(bResult, LEAFSIZE);
aResult = a12.minus(a22); // a12 - a22
bResult = b21.add(b22); // b21 + b22
Matrix p7 = aResult.strassenTimes(bResult, LEAFSIZE);
Matrix c12 = p3.add(p5); // c12 = p3 + p5
Matrix c21 = p2.add(p4); // c21 = p2 + p4
aResult = p1.add(p4); // p1 + p4
bResult = aResult.add(p7); // p1 + p4 + p7
Matrix c11 = bResult.minus(p5);
aResult = p1.add(p3); // p1 + p3
bResult = aResult.add(p6); // p1 + p3 + p6
Matrix c22 = bResult.minus(p2);
// Grouping the results obtained in a single matrix:
int rows = c11.nrRows();
int cols = c11.nrColumns();
Matrix C = new Matrix(A.M, B.N);
for (int i = 1; i <= A.M; i++) {
for (int j = 1; j <= B.N; j++) {
int el;
if (i <= rows) {
if (j <= cols) {
el = c11.get(i, j);
} else {
el = c12.get(i, j - cols);
}
} else {
if (j <= cols) {
el = c21.get(i - rows, j);
} else {
el = c22.get(i - rows, j - rows);
}
}
C.setElement(i, j, el);
}
}
return C;
}
小benchmark有如下代码:
int AM, AN, BM, BN;
AM = 500;
AN = BM = 500;
BN = 500;
Matrix a = new Matrix(AM, AN);
Matrix b = new Matrix(BM, BN);
Random random = new Random();
for (int i = 1; i <= AM; i++) {
for (int j = 1; j <= AN; j++) {
a.setElement(i, j, random.nextInt(20));
}
}
for (int i = 1; i <= BM; i++) {
for (int j = 1; j <= BN; j++) {
b.setElement(i, j, random.nextInt(20));
}
}
System.out.println("strassen: A x B");
long tijd = System.currentTimeMillis();
Matrix c = a.strassenTimes(b);
System.out.println("time = " + (System.currentTimeMillis() - tijd));
System.out.println("normal: A x B");
tijd = System.currentTimeMillis();
Matrix d = a.times(b);
System.out.println("time = " + (System.currentTimeMillis() - tijd));
System.out.println("nr of different elements = " + c.compare(d));
结果如下:
strassen: A x B
time = 18372
normal: A x B
time = 308
nr of different elements = 0
我知道这是一个低代码,但如果你们能帮助我,我会很高兴 ;)
编辑 1: 为了完整起见,我添加了上面代码使用的一些方法。
public int get(int r, int c) {
if (c > nrColumns() || r > nrRows() || c <= 0 || r <= 0) {
throw new ArrayIndexOutOfBoundsException("matrix is of size (" +
nrRows() + ", " + nrColumns() + "), but tries to set element(" + r + ", " + c + ")");
}
return content[r - 1][c - 1];
}
private int saveGet(int r, int c) {
if (c > nrColumns() || r > nrRows() || c <= 0 || r <= 0) {
return 0;
}
return content[r - 1][c - 1];
}
public void setElement(int r, int c, int n) {
if (c > nrColumns() || r > nrRows() || c <= 0 || r <= 0) {
throw new ArrayIndexOutOfBoundsException("matrix is of size (" +
nrRows() + ", " + nrColumns() + "), but tries to set element(" + r + ", " + c + ")");
}
content[r - 1][c - 1] = n;
}
// return C = A + B
public Matrix add(Matrix B) {
Matrix A = this;
if (B.M != A.M || B.N != A.N) throw new RuntimeException("Illegal matrix dimensions.");
Matrix C = new Matrix(M, N);
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
C.content[i][j] = A.content[i][j] + B.content[i][j];
}
}
return C;
}
我应该为 Strassen 的算法选择另一个叶大小。因此我做了一个小实验。似乎叶大小 256 最适合问题中包含的代码。在具有不同叶子大小的地块下方,每次都有一个大小为 1025 x 1025 的随机矩阵。
我将 Strassen 的叶大小为 256 的算法与用于矩阵乘法的简单算法进行了比较,看看它是否真的有所改进。事实证明这是一种改进,请参见下面不同大小的随机矩阵的结果(步长为 10,每个大小重复 50 次)。
下面是矩阵乘法的简单算法的代码:
// return C = A * B
public Matrix times(Matrix B) {
Matrix A = this;
if (A.N != B.M) throw new RuntimeException("Illegal matrix dimensions.");
Matrix C = new Matrix(A.M, B.N);
for (int i = 0; i < C.M; i++) {
for (int j = 0; j < C.N; j++) {
for (int k = 0; k < A.N; k++) {
C.content[i][j] += (A.content[i][k] * B.content[k][j]);
}
}
}
return C;
}
它仍然认为可以在实现上做其他改进,但事实证明叶子大小是一个非常重要的因素。所有实验都是在 Ubuntu 14.04 上使用机器 运行 完成的,规格如下:
CPU: Intel(R) Core(TM) i7-2600K CPU @ 3.40GHz
Memory: 2 x 4GB DDR3 1333 MHz