实现 Strassen(施特拉森)矩阵乘法算法时遇到的问题
Trouble With Implementing Strassen's Algorithm for Matrix Multiplication
在过去的几个小时里,我一直在尝试实现 Strassen 矩阵乘法算法,但始终得不到正确的乘积。我认为问题可能出在我的某个辅助函数(helpSub、createProd、helpProduct)上,也可能出在 strass2 函数本身的写法上(语句顺序等)。欢迎任何提示,因为我完全被难住了。我一直使用两个 4 x 4 矩阵作为测试矩阵。我已经尝试了很多在网上看到的 p1-p7 和 c1-c4 的公式变体,但似乎都不起作用。下面是我创建的类。
/* @author williamnewman */
// Question version of the Strassen multiplier: splits two matrices into
// quadrants, forms seven recursive products and recombines them.
// NOTE(review): this version is the one that produced the wrong "ACTUAL" output;
// the reviewer notes below mark the lines that disagree with the standard
// Strassen formulas.
public class strassen2 {
// Recursively multiplies x by y.
// Both matrices are indexed as square matrices of the same size; the split below
// uses x.length/2 for both, so a power-of-two size is presumably expected -- confirm.
int [][] strass2(int[][] x, int[][]y){
// Base case: both operands have exactly one row, so multiply the two scalars.
if(x.length == 1 && y.length == 1){
System.out.println("Donezo");
int [][] nu = new int[1][1];
nu[0][0] = x[0][0] * y[0][0];
return nu;
}
else{
int[][] a,b,c,d,e,f,g,h;
int dim = x.length/2;
// Split x into quadrants a (top-left), b (top-right), c (bottom-left),
// d (bottom-right) and y into e, f, g, h in the same order.
// Each quadrant is printed with C() right after it is extracted (debug output).
System.out.println("A<B<C");
a = helpSub(0,0,x);
C(a);
b = helpSub(0,dim,x);
C(b);
c = helpSub(dim,0,x);
C(c);
d = helpSub(dim,dim,x);
C(d);
e = helpSub(0,0,y);
C(e);
f = helpSub(0,dim,y);
C(f);
g = helpSub(dim,0,y);
C(g);
h = helpSub(dim,dim,y);
C(h);
int[][] p1,p2,p3,p4,p5,p6,p7;
// The seven recursive products p1-p7.
// NOTE(review): p2 and p3 below pass their operands as h*(a+b) and e*(c+d).
// The standard formulas are P2 = (a+b)*h and P3 = (c+d)*e; matrix
// multiplication is not commutative, so the operand order matters.
// NOTE(review): the lone '/' on the next line is not valid Java; it looks like
// the remnant of a removed comment marker -- confirm and delete.
/
p1 = strass2(a,subtract(f,h));
p2 = strass2(h, add(a,b));
p3 = strass2(e,add(c,d));
p4 = strass2(d,subtract(g,e));
p5 = strass2(add(a,d),add(e,h));
p6 = strass2(subtract(b,d),add(g,h));
p7 = strass2(subtract(a,c),add(e,f));
int [][] prod;
int [][] c1,c2,c3,c4;
// Combine the products into the four result quadrants c1-c4.
// NOTE(review): as written, c1 evaluates to p5 + p6 - p4 + p2 and c4 evaluates
// to p1 + p5 - p3 + p7. The standard formulas are
// C11 = P5 + P4 - P2 + P6 and C22 = P1 + P5 - P3 - P7, so the signs of
// p4/p2 in c1 and of p7 in c4 are flipped.
c1 = subtract(add(p6,p5),subtract(p4,p2));
c2 = add(p1,p2);
c3 = add(p3,p4);
c4 = subtract(add(p1,p5),subtract(p3,p7));
// Debug output: each quadrant is printed first, followed by its label.
C(c1);
System.out.println("C1::");
C(c2);
System.out.println("C2::");
C(c3);
System.out.println("C3::");
C(c4);
System.out.println("C4::");
// Assemble the four quadrants into the full product matrix.
prod = createProd(c1,c2,c3,c4);
return prod;
}
}
// Builds the (2n x 2n) product matrix from four (n x n) quadrants, where n is
// c1.length: c1 goes top-left, c2 top-right, c3 bottom-left, c4 bottom-right.
// Also prints a blank line, the "PRODUCT::!:" banner and the assembled matrix.
int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4){
int[][] product = new int[c1.length*2][c1.length*2];
int mid = c1.length;
int fin = c1.length * 2;
helpProduct(0,0,mid,mid,product,c1);
helpProduct(0,mid,mid,fin,product,c2);
helpProduct(mid,0,fin,mid,product,c3);
helpProduct(mid,mid,fin,fin,product,c4);
System.out.println();
System.out.println("PRODUCT::!:");
C(product);
return product;
}
// Copies a1 into the region of product covering rows [x, z1) and columns [y, z2).
// a1 is read starting from its own [0][0].
void helpProduct(int x, int y, int z1, int z2,int[][] product, int[][] a1){
// indR / indC track the read position inside a1.
int indR = 0;
int indC = 0;
for(int i = x; i < z1; i++){
// Restart at the first column of a1 for every new row.
indC = 0;
for(int j = y; j < z2; j++){
product[i][j] = a1[indR][indC];
indC++;
}
indR++;
}
}
// Returns a new square matrix of side mat.length/2 copied from mat, starting at
// row x and column y.
int[][] helpSub(int x, int y, int[][] mat){
int[][] sub = new int[mat.length/2][mat.length/2];
for(int i1 = 0, i2=x; i1 < (mat.length/2); i1++, i2++)
for(int j1 = 0, j2=y; j1<(mat.length/2); j1++, j2++)
{
sub[i1][j1] = mat[i2][j2];
// System.out.println(sub[i1][j1]);
}
return sub;
}
// Multiplies a by b through the MM class and returns its 'product' field.
// MM is not defined in this file; presumably it performs the conventional
// matrix multiplication -- verify against its source.
int[][] multiply(int[][]a,int[][]b){
MM nu = new MM(a,b);
return nu.product;
}
// Returns the element-wise sum a + b as a new matrix sized
// a.length x a[0].length; b is indexed with the same positions as a.
int[][] add(int[][]a, int[][]b){
int [][] nu = new int[a.length][a[0].length];
for(int i = 0; i < a.length; i++){
for(int j = 0; j < a[i].length;j++){
nu[i][j] = a[i][j] + b[i][j];
}
}
return nu;
}
// Returns the element-wise difference a - b as a new matrix.
// NOTE(review): the result is allocated as a.length x a.length while the loop
// runs to a[i].length, so this only lines up for square inputs (add() above
// uses a[0].length for the column count instead).
int[][] subtract(int[][] a, int[][] b){
int [][] sub = new int[a.length][a.length];
//System.out.println("made it");
for(int i = 0; i < a.length; i++){
for(int j = 0; j < a[i].length;j++){
sub[i][j] = a[i][j] - b[i][j];
}
}
return sub;
}
// Prints the matrix to standard output, one line per outer-loop step.
// NOTE(review): the element printed is product[j][i], so the matrix is shown
// transposed relative to how it is stored -- confirm whether that is intended.
void C(int[][] product){
for(int i = 0; i <product.length; i++){
for(int j = 0; j < product[i].length; j++){
System.out.print(product[j][i] + " ");
}
System.out.println();
}
}
}
如果有任何令人困惑的地方,请告诉我,我会更新问题!
这是 main 函数:
/**
 * Test driver: multiplies two fixed 4x4 matrices with the project's MM class
 * and with strassen2, printing both results separated by a "----" line.
 * MM is not defined in this file; presumably its C() prints the product it
 * computed -- verify against its source.
 */
public static void main(String[] args) {
    // Left operand.
    final int[][] left = {
        {1, 2, 3, 4},
        {4, 3, 2, 1},
        {1, 2, 3, 4},
        {4, 3, 2, 1}
    };
    // Right operand.
    final int[][] right = {
        {3, 4, 5, 6},
        {3, 4, 5, 6},
        {5, 4, 3, 2},
        {5, 4, 3, 2}
    };

    // First result: printed by the MM instance itself.
    MM reference = new MM(left, right);
    reference.C();

    System.out.println("----");

    // Second result: computed by strass2 and printed with strassen2.C().
    strassen2 strassen = new strassen2();
    int[][] strassenProduct = strassen.strass2(left, right);
    strassen.C(strassenProduct);
}
}
这是目前的结果(预期结果是显示的第一个 4x4 矩阵,实际结果是显示的最后一个 4x4 矩阵):
EXPECTED:
44 40 36 32
36 40 44 48
44 40 36 32
36 40 44 48
----
ACTUAL::
70 78 50 42
86 86 34 34
30 38 30 38
38 54 38 54
我很确定我的 helpSub() 函数可以正常工作,因为它生成了正确的 a-h。但是,我在 strass2 递归调用中使用的参数可能有问题。很抱歉问题描述得不够具体,我只是有点精疲力尽,很想知道是否有人能看出任何明显的问题。
抱歉之前描述得含糊不清,不过我似乎已经解决了这个问题。我使用了这个网站上的公式来计算 p1-p7 和 c1-c4:[施特拉森矩阵乘法公式][1]
[1]: http://www.stoimen.com/blog/2012/11/26/computer-algorithms-strassens-matrix-multiplication/
采用这些公式后,乘积矩阵几乎正确,但仍有 4 个左右的值不正确。然后我把基本情况改为在 x 和 y 的长度不超过 2 时触发,这似乎修正了那 4 个有偏差的值。对于好奇的人,下面是我修改后的 strassen2 类的代码。
/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package pkg2a;
/**
*
* @author williamnewman
*/
/**
 * Strassen matrix multiplication for {@code int} matrices.
 *
 * <p>With {@code x = [a b; c d]} and {@code y = [e f; g h]} the algorithm forms
 * seven recursive products
 * <pre>
 *   p1 = a(f - h)      p2 = (a + b)h      p3 = (c + d)e      p4 = d(g - e)
 *   p5 = (a + d)(e + h)   p6 = (b - d)(g + h)   p7 = (a - c)(e + f)
 * </pre>
 * and recombines them as
 * <pre>
 *   c1 = p5 + p4 - p2 + p6     c2 = p1 + p2
 *   c3 = p3 + p4               c4 = p1 + p5 - p3 - p7
 * </pre>
 *
 * <p>Inputs that cannot be split into equal quadrants are handed to
 * {@link #multiply(int[][], int[][])}.
 */
public class strassen2 {

    /**
     * Multiplies {@code x} by {@code y}.
     *
     * <p>The recursive split is used only when both operands are square, have
     * the same even side length, and that length is greater than 2. Every other
     * input is delegated to {@link #multiply(int[][], int[][])}, which includes
     * the sizes of 2 or less that were already delegated before. Halving an odd
     * side length would silently drop the last row and column, so those sizes
     * must not be split.
     *
     * @param x left operand
     * @param y right operand
     * @return the product matrix
     */
    int[][] strass2(int[][] x, int[][] y) {
        if (!canSplitIntoQuadrants(x, y)) {
            return multiply(x, y);
        }

        int half = x.length / 2;

        // Quadrants of x: a = top-left, b = top-right, c = bottom-left, d = bottom-right.
        int[][] a = helpSub(0, 0, x);
        int[][] b = helpSub(0, half, x);
        int[][] c = helpSub(half, 0, x);
        int[][] d = helpSub(half, half, x);

        // Quadrants of y in the same order.
        int[][] e = helpSub(0, 0, y);
        int[][] f = helpSub(0, half, y);
        int[][] g = helpSub(half, 0, y);
        int[][] h = helpSub(half, half, y);

        // Operand order matters: matrix multiplication is not commutative.
        int[][] p1 = strass2(a, subtract(f, h));
        int[][] p2 = strass2(add(a, b), h);
        int[][] p3 = strass2(add(c, d), e);
        int[][] p4 = strass2(d, subtract(g, e));
        int[][] p5 = strass2(add(a, d), add(e, h));
        int[][] p6 = strass2(subtract(b, d), add(g, h));
        int[][] p7 = strass2(subtract(a, c), add(e, f));

        int[][] c1 = subtract(add(p5, p4), subtract(p2, p6)); // p5 + p4 - p2 + p6
        int[][] c2 = add(p1, p2);
        int[][] c3 = add(p3, p4);
        int[][] c4 = subtract(add(p1, p5), add(p3, p7));      // p1 + p5 - p3 - p7

        return createProd(c1, c2, c3, c4);
    }

    /**
     * Tells whether both matrices are square, share the same even side length,
     * and that length is greater than 2.
     */
    private static boolean canSplitIntoQuadrants(int[][] x, int[][] y) {
        int n = x.length;
        if (n <= 2 || n % 2 != 0 || y.length != n) {
            return false;
        }
        for (int i = 0; i < n; i++) {
            if (x[i] == null || x[i].length != n || y[i] == null || y[i].length != n) {
                return false;
            }
        }
        return true;
    }

    /**
     * Assembles four equally sized square quadrants into one matrix of twice
     * the side length.
     *
     * @param c1 top-left quadrant
     * @param c2 top-right quadrant
     * @param c3 bottom-left quadrant
     * @param c4 bottom-right quadrant
     * @return a new matrix of side {@code 2 * c1.length}
     */
    int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4) {
        int mid = c1.length;
        int fin = mid * 2;
        int[][] product = new int[fin][fin];
        helpProduct(0, 0, mid, mid, product, c1);
        helpProduct(0, mid, mid, fin, product, c2);
        helpProduct(mid, 0, fin, mid, product, c3);
        helpProduct(mid, mid, fin, fin, product, c4);
        return product;
    }

    /**
     * Copies {@code a1} into the region of {@code product} covering rows
     * {@code [x, z1)} and columns {@code [y, z2)}; {@code a1} is read from its
     * own {@code [0][0]}.
     *
     * @param x first destination row (inclusive)
     * @param y first destination column (inclusive)
     * @param z1 last destination row (exclusive)
     * @param z2 last destination column (exclusive)
     * @param product destination matrix, modified in place
     * @param a1 source matrix
     */
    void helpProduct(int x, int y, int z1, int z2, int[][] product, int[][] a1) {
        int width = z2 - y;
        if (width <= 0) {
            return; // empty column range: nothing to copy
        }
        for (int row = x; row < z1; row++) {
            System.arraycopy(a1[row - x], 0, product[row], y, width);
        }
    }

    /**
     * Extracts the square sub-matrix of side {@code mat.length / 2} whose
     * top-left corner is at row {@code x}, column {@code y}.
     *
     * @param x first source row
     * @param y first source column
     * @param mat source matrix
     * @return a new matrix holding the copied quadrant
     */
    int[][] helpSub(int x, int y, int[][] mat) {
        int half = mat.length / 2;
        int[][] sub = new int[half][half];
        for (int row = 0; row < half; row++) {
            System.arraycopy(mat[x + row], y, sub[row], 0, half);
        }
        return sub;
    }

    /**
     * Multiplies {@code a} by {@code b} through the project's {@code MM} class
     * and returns its {@code product} field. {@code MM} is not defined in this
     * file; presumably it performs the conventional multiplication -- verify
     * against its source.
     */
    int[][] multiply(int[][] a, int[][] b) {
        MM nu = new MM(a, b);
        return nu.product;
    }

    /**
     * Returns the element-wise sum {@code a + b} as a new matrix.
     *
     * @param a first operand; its dimensions define the result
     * @param b second operand, read at the same positions as {@code a}
     * @return a new {@code a.length x a[0].length} matrix, or an empty matrix
     *         when {@code a} has no rows
     */
    int[][] add(int[][] a, int[][] b) {
        int rows = a.length;
        int cols = rows == 0 ? 0 : a[0].length;
        int[][] sum = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                sum[i][j] = a[i][j] + b[i][j];
            }
        }
        return sum;
    }

    /**
     * Returns the element-wise difference {@code a - b} as a new matrix.
     *
     * @param a minuend; its dimensions define the result
     * @param b subtrahend, read at the same positions as {@code a}
     * @return a new {@code a.length x a[0].length} matrix, or an empty matrix
     *         when {@code a} has no rows
     */
    int[][] subtract(int[][] a, int[][] b) {
        int rows = a.length;
        int cols = rows == 0 ? 0 : a[0].length;
        // Sized rows x cols (not rows x rows) so that it agrees with add().
        int[][] diff = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                diff[i][j] = a[i][j] - b[i][j];
            }
        }
        return diff;
    }

    /**
     * Prints the matrix to standard output in row-major order: one line per
     * row, each element followed by a single space.
     */
    void C(int[][] product) {
        for (int i = 0; i < product.length; i++) {
            for (int j = 0; j < product[i].length; j++) {
                System.out.print(product[i][j] + " ");
            }
            System.out.println();
        }
    }
}
在过去的几个小时里,我一直在尝试实现 Strassen 矩阵乘法算法,但始终得不到正确的乘积。我认为问题可能出在我的某个辅助函数(helpSub、createProd、helpProduct)上,也可能出在 strass2 函数本身的写法上(语句顺序等)。欢迎任何提示,因为我完全被难住了。我一直使用两个 4 x 4 矩阵作为测试矩阵。我已经尝试了很多在网上看到的 p1-p7 和 c1-c4 的公式变体,但似乎都不起作用。下面是我创建的类。
/* @author williamnewman */
// Question version of the Strassen multiplier: splits two matrices into
// quadrants, forms seven recursive products and recombines them.
// NOTE(review): this version is the one that produced the wrong "ACTUAL" output;
// the reviewer notes below mark the lines that disagree with the standard
// Strassen formulas.
public class strassen2 {
// Recursively multiplies x by y.
// Both matrices are indexed as square matrices of the same size; the split below
// uses x.length/2 for both, so a power-of-two size is presumably expected -- confirm.
int [][] strass2(int[][] x, int[][]y){
// Base case: both operands have exactly one row, so multiply the two scalars.
if(x.length == 1 && y.length == 1){
System.out.println("Donezo");
int [][] nu = new int[1][1];
nu[0][0] = x[0][0] * y[0][0];
return nu;
}
else{
int[][] a,b,c,d,e,f,g,h;
int dim = x.length/2;
// Split x into quadrants a (top-left), b (top-right), c (bottom-left),
// d (bottom-right) and y into e, f, g, h in the same order.
// Each quadrant is printed with C() right after it is extracted (debug output).
System.out.println("A<B<C");
a = helpSub(0,0,x);
C(a);
b = helpSub(0,dim,x);
C(b);
c = helpSub(dim,0,x);
C(c);
d = helpSub(dim,dim,x);
C(d);
e = helpSub(0,0,y);
C(e);
f = helpSub(0,dim,y);
C(f);
g = helpSub(dim,0,y);
C(g);
h = helpSub(dim,dim,y);
C(h);
int[][] p1,p2,p3,p4,p5,p6,p7;
// The seven recursive products p1-p7.
// NOTE(review): p2 and p3 below pass their operands as h*(a+b) and e*(c+d).
// The standard formulas are P2 = (a+b)*h and P3 = (c+d)*e; matrix
// multiplication is not commutative, so the operand order matters.
// NOTE(review): the lone '/' on the next line is not valid Java; it looks like
// the remnant of a removed comment marker -- confirm and delete.
/
p1 = strass2(a,subtract(f,h));
p2 = strass2(h, add(a,b));
p3 = strass2(e,add(c,d));
p4 = strass2(d,subtract(g,e));
p5 = strass2(add(a,d),add(e,h));
p6 = strass2(subtract(b,d),add(g,h));
p7 = strass2(subtract(a,c),add(e,f));
int [][] prod;
int [][] c1,c2,c3,c4;
// Combine the products into the four result quadrants c1-c4.
// NOTE(review): as written, c1 evaluates to p5 + p6 - p4 + p2 and c4 evaluates
// to p1 + p5 - p3 + p7. The standard formulas are
// C11 = P5 + P4 - P2 + P6 and C22 = P1 + P5 - P3 - P7, so the signs of
// p4/p2 in c1 and of p7 in c4 are flipped.
c1 = subtract(add(p6,p5),subtract(p4,p2));
c2 = add(p1,p2);
c3 = add(p3,p4);
c4 = subtract(add(p1,p5),subtract(p3,p7));
// Debug output: each quadrant is printed first, followed by its label.
C(c1);
System.out.println("C1::");
C(c2);
System.out.println("C2::");
C(c3);
System.out.println("C3::");
C(c4);
System.out.println("C4::");
// Assemble the four quadrants into the full product matrix.
prod = createProd(c1,c2,c3,c4);
return prod;
}
}
// Builds the (2n x 2n) product matrix from four (n x n) quadrants, where n is
// c1.length: c1 goes top-left, c2 top-right, c3 bottom-left, c4 bottom-right.
// Also prints a blank line, the "PRODUCT::!:" banner and the assembled matrix.
int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4){
int[][] product = new int[c1.length*2][c1.length*2];
int mid = c1.length;
int fin = c1.length * 2;
helpProduct(0,0,mid,mid,product,c1);
helpProduct(0,mid,mid,fin,product,c2);
helpProduct(mid,0,fin,mid,product,c3);
helpProduct(mid,mid,fin,fin,product,c4);
System.out.println();
System.out.println("PRODUCT::!:");
C(product);
return product;
}
// Copies a1 into the region of product covering rows [x, z1) and columns [y, z2).
// a1 is read starting from its own [0][0].
void helpProduct(int x, int y, int z1, int z2,int[][] product, int[][] a1){
// indR / indC track the read position inside a1.
int indR = 0;
int indC = 0;
for(int i = x; i < z1; i++){
// Restart at the first column of a1 for every new row.
indC = 0;
for(int j = y; j < z2; j++){
product[i][j] = a1[indR][indC];
indC++;
}
indR++;
}
}
// Returns a new square matrix of side mat.length/2 copied from mat, starting at
// row x and column y.
int[][] helpSub(int x, int y, int[][] mat){
int[][] sub = new int[mat.length/2][mat.length/2];
for(int i1 = 0, i2=x; i1 < (mat.length/2); i1++, i2++)
for(int j1 = 0, j2=y; j1<(mat.length/2); j1++, j2++)
{
sub[i1][j1] = mat[i2][j2];
// System.out.println(sub[i1][j1]);
}
return sub;
}
// Multiplies a by b through the MM class and returns its 'product' field.
// MM is not defined in this file; presumably it performs the conventional
// matrix multiplication -- verify against its source.
int[][] multiply(int[][]a,int[][]b){
MM nu = new MM(a,b);
return nu.product;
}
// Returns the element-wise sum a + b as a new matrix sized
// a.length x a[0].length; b is indexed with the same positions as a.
int[][] add(int[][]a, int[][]b){
int [][] nu = new int[a.length][a[0].length];
for(int i = 0; i < a.length; i++){
for(int j = 0; j < a[i].length;j++){
nu[i][j] = a[i][j] + b[i][j];
}
}
return nu;
}
// Returns the element-wise difference a - b as a new matrix.
// NOTE(review): the result is allocated as a.length x a.length while the loop
// runs to a[i].length, so this only lines up for square inputs (add() above
// uses a[0].length for the column count instead).
int[][] subtract(int[][] a, int[][] b){
int [][] sub = new int[a.length][a.length];
//System.out.println("made it");
for(int i = 0; i < a.length; i++){
for(int j = 0; j < a[i].length;j++){
sub[i][j] = a[i][j] - b[i][j];
}
}
return sub;
}
// Prints the matrix to standard output, one line per outer-loop step.
// NOTE(review): the element printed is product[j][i], so the matrix is shown
// transposed relative to how it is stored -- confirm whether that is intended.
void C(int[][] product){
for(int i = 0; i <product.length; i++){
for(int j = 0; j < product[i].length; j++){
System.out.print(product[j][i] + " ");
}
System.out.println();
}
}
}
如果有任何令人困惑的地方,请告诉我,我会更新问题!
这是 main 函数:
/**
 * Test driver: multiplies two fixed 4x4 matrices with the project's MM class
 * and with strassen2, printing both results separated by a "----" line.
 * MM is not defined in this file; presumably its C() prints the product it
 * computed -- verify against its source.
 */
public static void main(String[] args) {
    // Left operand.
    final int[][] left = {
        {1, 2, 3, 4},
        {4, 3, 2, 1},
        {1, 2, 3, 4},
        {4, 3, 2, 1}
    };
    // Right operand.
    final int[][] right = {
        {3, 4, 5, 6},
        {3, 4, 5, 6},
        {5, 4, 3, 2},
        {5, 4, 3, 2}
    };

    // First result: printed by the MM instance itself.
    MM reference = new MM(left, right);
    reference.C();

    System.out.println("----");

    // Second result: computed by strass2 and printed with strassen2.C().
    strassen2 strassen = new strassen2();
    int[][] strassenProduct = strassen.strass2(left, right);
    strassen.C(strassenProduct);
}
}
这是目前的结果(预期结果是显示的第一个 4x4 矩阵,实际结果是显示的最后一个 4x4 矩阵):
EXPECTED:
44 40 36 32
36 40 44 48
44 40 36 32
36 40 44 48
----
ACTUAL::
70 78 50 42
86 86 34 34
30 38 30 38
38 54 38 54
我很确定我的 helpSub() 函数可以正常工作,因为它生成了正确的 a-h。但是,我在 strass2 递归调用中使用的参数可能有问题。很抱歉问题描述得不够具体,我只是有点精疲力尽,很想知道是否有人能看出任何明显的问题。
抱歉之前描述得含糊不清,不过我似乎已经解决了这个问题。我使用了这个网站上的公式来计算 p1-p7 和 c1-c4:[施特拉森矩阵乘法公式][1]
[1]: http://www.stoimen.com/blog/2012/11/26/computer-algorithms-strassens-matrix-multiplication/
采用这些公式后,乘积矩阵几乎正确,但仍有 4 个左右的值不正确。然后我把基本情况改为在 x 和 y 的长度不超过 2 时触发,这似乎修正了那 4 个有偏差的值。对于好奇的人,下面是我修改后的 strassen2 类的代码。
/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package pkg2a;
/**
*
* @author williamnewman
*/
/**
 * Strassen matrix multiplication for {@code int} matrices.
 *
 * <p>With {@code x = [a b; c d]} and {@code y = [e f; g h]} the algorithm forms
 * seven recursive products
 * <pre>
 *   p1 = a(f - h)      p2 = (a + b)h      p3 = (c + d)e      p4 = d(g - e)
 *   p5 = (a + d)(e + h)   p6 = (b - d)(g + h)   p7 = (a - c)(e + f)
 * </pre>
 * and recombines them as
 * <pre>
 *   c1 = p5 + p4 - p2 + p6     c2 = p1 + p2
 *   c3 = p3 + p4               c4 = p1 + p5 - p3 - p7
 * </pre>
 *
 * <p>Inputs that cannot be split into equal quadrants are handed to
 * {@link #multiply(int[][], int[][])}.
 */
public class strassen2 {

    /**
     * Multiplies {@code x} by {@code y}.
     *
     * <p>The recursive split is used only when both operands are square, have
     * the same even side length, and that length is greater than 2. Every other
     * input is delegated to {@link #multiply(int[][], int[][])}, which includes
     * the sizes of 2 or less that were already delegated before. Halving an odd
     * side length would silently drop the last row and column, so those sizes
     * must not be split.
     *
     * @param x left operand
     * @param y right operand
     * @return the product matrix
     */
    int[][] strass2(int[][] x, int[][] y) {
        if (!canSplitIntoQuadrants(x, y)) {
            return multiply(x, y);
        }

        int half = x.length / 2;

        // Quadrants of x: a = top-left, b = top-right, c = bottom-left, d = bottom-right.
        int[][] a = helpSub(0, 0, x);
        int[][] b = helpSub(0, half, x);
        int[][] c = helpSub(half, 0, x);
        int[][] d = helpSub(half, half, x);

        // Quadrants of y in the same order.
        int[][] e = helpSub(0, 0, y);
        int[][] f = helpSub(0, half, y);
        int[][] g = helpSub(half, 0, y);
        int[][] h = helpSub(half, half, y);

        // Operand order matters: matrix multiplication is not commutative.
        int[][] p1 = strass2(a, subtract(f, h));
        int[][] p2 = strass2(add(a, b), h);
        int[][] p3 = strass2(add(c, d), e);
        int[][] p4 = strass2(d, subtract(g, e));
        int[][] p5 = strass2(add(a, d), add(e, h));
        int[][] p6 = strass2(subtract(b, d), add(g, h));
        int[][] p7 = strass2(subtract(a, c), add(e, f));

        int[][] c1 = subtract(add(p5, p4), subtract(p2, p6)); // p5 + p4 - p2 + p6
        int[][] c2 = add(p1, p2);
        int[][] c3 = add(p3, p4);
        int[][] c4 = subtract(add(p1, p5), add(p3, p7));      // p1 + p5 - p3 - p7

        return createProd(c1, c2, c3, c4);
    }

    /**
     * Tells whether both matrices are square, share the same even side length,
     * and that length is greater than 2.
     */
    private static boolean canSplitIntoQuadrants(int[][] x, int[][] y) {
        int n = x.length;
        if (n <= 2 || n % 2 != 0 || y.length != n) {
            return false;
        }
        for (int i = 0; i < n; i++) {
            if (x[i] == null || x[i].length != n || y[i] == null || y[i].length != n) {
                return false;
            }
        }
        return true;
    }

    /**
     * Assembles four equally sized square quadrants into one matrix of twice
     * the side length.
     *
     * @param c1 top-left quadrant
     * @param c2 top-right quadrant
     * @param c3 bottom-left quadrant
     * @param c4 bottom-right quadrant
     * @return a new matrix of side {@code 2 * c1.length}
     */
    int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4) {
        int mid = c1.length;
        int fin = mid * 2;
        int[][] product = new int[fin][fin];
        helpProduct(0, 0, mid, mid, product, c1);
        helpProduct(0, mid, mid, fin, product, c2);
        helpProduct(mid, 0, fin, mid, product, c3);
        helpProduct(mid, mid, fin, fin, product, c4);
        return product;
    }

    /**
     * Copies {@code a1} into the region of {@code product} covering rows
     * {@code [x, z1)} and columns {@code [y, z2)}; {@code a1} is read from its
     * own {@code [0][0]}.
     *
     * @param x first destination row (inclusive)
     * @param y first destination column (inclusive)
     * @param z1 last destination row (exclusive)
     * @param z2 last destination column (exclusive)
     * @param product destination matrix, modified in place
     * @param a1 source matrix
     */
    void helpProduct(int x, int y, int z1, int z2, int[][] product, int[][] a1) {
        int width = z2 - y;
        if (width <= 0) {
            return; // empty column range: nothing to copy
        }
        for (int row = x; row < z1; row++) {
            System.arraycopy(a1[row - x], 0, product[row], y, width);
        }
    }

    /**
     * Extracts the square sub-matrix of side {@code mat.length / 2} whose
     * top-left corner is at row {@code x}, column {@code y}.
     *
     * @param x first source row
     * @param y first source column
     * @param mat source matrix
     * @return a new matrix holding the copied quadrant
     */
    int[][] helpSub(int x, int y, int[][] mat) {
        int half = mat.length / 2;
        int[][] sub = new int[half][half];
        for (int row = 0; row < half; row++) {
            System.arraycopy(mat[x + row], y, sub[row], 0, half);
        }
        return sub;
    }

    /**
     * Multiplies {@code a} by {@code b} through the project's {@code MM} class
     * and returns its {@code product} field. {@code MM} is not defined in this
     * file; presumably it performs the conventional multiplication -- verify
     * against its source.
     */
    int[][] multiply(int[][] a, int[][] b) {
        MM nu = new MM(a, b);
        return nu.product;
    }

    /**
     * Returns the element-wise sum {@code a + b} as a new matrix.
     *
     * @param a first operand; its dimensions define the result
     * @param b second operand, read at the same positions as {@code a}
     * @return a new {@code a.length x a[0].length} matrix, or an empty matrix
     *         when {@code a} has no rows
     */
    int[][] add(int[][] a, int[][] b) {
        int rows = a.length;
        int cols = rows == 0 ? 0 : a[0].length;
        int[][] sum = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                sum[i][j] = a[i][j] + b[i][j];
            }
        }
        return sum;
    }

    /**
     * Returns the element-wise difference {@code a - b} as a new matrix.
     *
     * @param a minuend; its dimensions define the result
     * @param b subtrahend, read at the same positions as {@code a}
     * @return a new {@code a.length x a[0].length} matrix, or an empty matrix
     *         when {@code a} has no rows
     */
    int[][] subtract(int[][] a, int[][] b) {
        int rows = a.length;
        int cols = rows == 0 ? 0 : a[0].length;
        // Sized rows x cols (not rows x rows) so that it agrees with add().
        int[][] diff = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                diff[i][j] = a[i][j] - b[i][j];
            }
        }
        return diff;
    }

    /**
     * Prints the matrix to standard output in row-major order: one line per
     * row, each element followed by a single space.
     */
    void C(int[][] product) {
        for (int i = 0; i < product.length; i++) {
            for (int j = 0; j < product[i].length; j++) {
                System.out.print(product[i][j] + " ");
            }
            System.out.println();
        }
    }
}