C 中的 Strassen 乘法
Strassen's multiplication in C
请看下面的代码:
#include<stdio.h>
#include<stdlib.h>
int **divide(int **Matrix,int n,int position)
{
int i,j;
int **Partition=malloc(sizeof(*Partition)*n);
for(i=0;i<n;i++)
{
Partition[i]=calloc(n,sizeof(*Partition[i]));
}
if(position==1)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i][j];
}
}
}
else if(position==2)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i][j+n/2];
}
}
}
else if(position==3)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i+n/2][j];
}
}
}
else if(position==4)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i+n/2][j+n/2];
}
}
}
return Partition;
}
int **allocate(int n)
{
int **newmatrix=malloc(sizeof(*newmatrix)*n);
for(int i=0;i<n;i++)
{
newmatrix[i]=calloc(n, sizeof(*newmatrix[i]));
}
return newmatrix;
}
void mfree(int **matrix,int n) {
for (int i=0;i<n;i++) {
free(matrix[i]);
}
free(matrix);
}
int **add(int **a,int **b,int n)
{
int **c=allocate(n);
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
c[i][j]=a[i][j]+b[i][j];
}
}
return c;
}
int **subtract(int **a,int **b,int n)
{
int **c=allocate(n);
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
c[i][j]=a[i][j]-b[i][j];
}
}
return c;
}
void print(int **Matrix,int n)
{
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
printf("%d ",Matrix[i][j]);
}
printf("\n");
}
}
int **Strassens(int **A,int **B,int n)
{
int **C=allocate(n);
if(n==1)
{
C[0][0]=A[0][0]*B[0][0];
}
else
{ //Allocate the submatrices
int **a11=allocate(n/2);
int **a12=allocate(n/2);
int **a21=allocate(n/2);
int **a22=allocate(n/2);
int **b11=allocate(n/2);
int **b12=allocate(n/2);
int **b21=allocate(n/2);
int **b22=allocate(n/2);
a11=divide(A,n,1);
a12=divide(A,n,2);
a21=divide(A,n,3);
a22=divide(A,n,4);
b11=divide(B,n,1);
b12=divide(B,n,2);
b21=divide(B,n,3);
b22=divide(B,n,4);
int **s1=subtract(b12,b22,n/2);
int **s2=add(a11,a12,n/2);
int **s3=add(a21,a22,n/2);
int **s4=subtract(b21,b11,n/2);
int **s5=add(a11,a22,n/2);
int **s6=add(b11,b22,n/2);
int **s7=subtract(a12,a22,n/2);
int **s8=add(b21,b22,n/2);
int **s9=subtract(a11,a21,n/2);
int **s10=add(b11,a12,n/2);
int **p1=Strassens(a11,s1,n/2);
int **p2=Strassens(s2,b22,n/2);
int **p3=Strassens(s3,b11,n/2);
int **p4=Strassens(a22,s4,n/2);
int **p5=Strassens(s5,s6,n/2);
int **p6=Strassens(s7,s8,n/2);
int **p7=Strassens(s9,s10,n/2);
int **c11=subtract(add(p5,p4,n/2),add(p2,p6,n/2),n/2);
int **c12=add(p1,p2,n/2);
int **c21=add(p3,p4,n/2);
int **c22=subtract(add(p5,p1,n/2),subtract(p3,p7,n/2),n/2);
for(int i=0;i<n/2;i++)
{
for(int j=0;j<n/2;j++)
{
C[i][j]=c11[i][j];
C[i][j+n/2]=c12[i][j];
C[i+n/2][j]=c21[i][j];
C[i+n/2][j+n/2]=c22[i][j];
}
}
}
return C;
}
int main()
{
int n=8; //Dimension of the square matrix, n*n;
int **A=allocate(n);
int **B=allocate(n);
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
A[i][j]=j+1;
B[i][j]=j+1;
}
}
printf("Matrix A:\n");
print(A,n);
printf("Matrix B: \n");
print(B,n);
printf("\n...Performing Multiplication with Strassen's...\nMatrix A*B:\n");
int **C = Strassens(A,B,n);
print(C,n);
mfree(C,n);
}
我知道这是一个非常愚蠢的问题,数学有问题。但我不知道我哪里出错了。
问题是,当我将两个具有相等值的矩阵相乘时,我得到了想要的结果,但这不适用于具有不同值的矩阵。
例如,看看输出:
Matrix A:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
Matrix B:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
...Performing Multiplication with Strassen's...
Matrix A*B:
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
和
Matrix A:
1 2 3 4 5 6 7 8
2 3 4 5 6 7 8 9
3 4 5 6 7 8 9 10
4 5 6 7 8 9 10 11
5 6 7 8 9 10 11 12
6 7 8 9 10 11 12 13
7 8 9 10 11 12 13 14
8 9 10 11 12 13 14 15
Matrix B:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
...Performing Multiplication with Strassen's...
Matrix A*B:
316 424 484 528 460 440 372 288
300 398 452 426 412 366 308 154
268 360 414 446 348 312 246 134
252 254 382 424 300 126 182 112
156 232 260 272 404 352 252 136
140 150 228 34 356 334 188 138
108 168 70 54 292 224 246 118
92 -122 38 24 244 222 182 104
:_) 抱歉。
这部分有一个轻微的数学错误:
int **c11=subtract(add(p5,p4,n/2),add(p2,p6,n/2),n/2);
int **c12=add(p1,p2,n/2);
int **c21=add(p3,p4,n/2);
int **c22=subtract(add(p5,p1,n/2),subtract(p3,p7,n/2),n/2);
用
替换c11和c22
int **c11=subtract(add(add(p5,p4,n/2),p6,n/2),p2,n/2);
...
int **c22=subtract(subtract(add(p5,p1,n/2),p3,n/2),p7,n/2);
更正数学错误。
请看下面的代码:
#include<stdio.h>
#include<stdlib.h>
int **divide(int **Matrix,int n,int position)
{
int i,j;
int **Partition=malloc(sizeof(*Partition)*n);
for(i=0;i<n;i++)
{
Partition[i]=calloc(n,sizeof(*Partition[i]));
}
if(position==1)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i][j];
}
}
}
else if(position==2)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i][j+n/2];
}
}
}
else if(position==3)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i+n/2][j];
}
}
}
else if(position==4)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i+n/2][j+n/2];
}
}
}
return Partition;
}
int **allocate(int n)
{
int **newmatrix=malloc(sizeof(*newmatrix)*n);
for(int i=0;i<n;i++)
{
newmatrix[i]=calloc(n, sizeof(*newmatrix[i]));
}
return newmatrix;
}
void mfree(int **matrix,int n) {
for (int i=0;i<n;i++) {
free(matrix[i]);
}
free(matrix);
}
int **add(int **a,int **b,int n)
{
int **c=allocate(n);
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
c[i][j]=a[i][j]+b[i][j];
}
}
return c;
}
int **subtract(int **a,int **b,int n)
{
int **c=allocate(n);
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
c[i][j]=a[i][j]-b[i][j];
}
}
return c;
}
void print(int **Matrix,int n)
{
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
printf("%d ",Matrix[i][j]);
}
printf("\n");
}
}
int **Strassens(int **A,int **B,int n)
{
int **C=allocate(n);
if(n==1)
{
C[0][0]=A[0][0]*B[0][0];
}
else
{ //Allocate the submatrices
int **a11=allocate(n/2);
int **a12=allocate(n/2);
int **a21=allocate(n/2);
int **a22=allocate(n/2);
int **b11=allocate(n/2);
int **b12=allocate(n/2);
int **b21=allocate(n/2);
int **b22=allocate(n/2);
a11=divide(A,n,1);
a12=divide(A,n,2);
a21=divide(A,n,3);
a22=divide(A,n,4);
b11=divide(B,n,1);
b12=divide(B,n,2);
b21=divide(B,n,3);
b22=divide(B,n,4);
int **s1=subtract(b12,b22,n/2);
int **s2=add(a11,a12,n/2);
int **s3=add(a21,a22,n/2);
int **s4=subtract(b21,b11,n/2);
int **s5=add(a11,a22,n/2);
int **s6=add(b11,b22,n/2);
int **s7=subtract(a12,a22,n/2);
int **s8=add(b21,b22,n/2);
int **s9=subtract(a11,a21,n/2);
int **s10=add(b11,a12,n/2);
int **p1=Strassens(a11,s1,n/2);
int **p2=Strassens(s2,b22,n/2);
int **p3=Strassens(s3,b11,n/2);
int **p4=Strassens(a22,s4,n/2);
int **p5=Strassens(s5,s6,n/2);
int **p6=Strassens(s7,s8,n/2);
int **p7=Strassens(s9,s10,n/2);
int **c11=subtract(add(p5,p4,n/2),add(p2,p6,n/2),n/2);
int **c12=add(p1,p2,n/2);
int **c21=add(p3,p4,n/2);
int **c22=subtract(add(p5,p1,n/2),subtract(p3,p7,n/2),n/2);
for(int i=0;i<n/2;i++)
{
for(int j=0;j<n/2;j++)
{
C[i][j]=c11[i][j];
C[i][j+n/2]=c12[i][j];
C[i+n/2][j]=c21[i][j];
C[i+n/2][j+n/2]=c22[i][j];
}
}
}
return C;
}
int main()
{
int n=8; //Dimension of the square matrix, n*n;
int **A=allocate(n);
int **B=allocate(n);
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
A[i][j]=j+1;
B[i][j]=j+1;
}
}
printf("Matrix A:\n");
print(A,n);
printf("Matrix B: \n");
print(B,n);
printf("\n...Performing Multiplication with Strassen's...\nMatrix A*B:\n");
int **C = Strassens(A,B,n);
print(C,n);
mfree(C,n);
}
我知道这是一个非常愚蠢的问题,数学有问题。但我不知道我哪里出错了。 问题是,当我将两个具有相等值的矩阵相乘时,我得到了想要的结果,但这不适用于具有不同值的矩阵。 例如,看看输出:
Matrix A:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
Matrix B:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
...Performing Multiplication with Strassen's...
Matrix A*B:
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
和
Matrix A:
1 2 3 4 5 6 7 8
2 3 4 5 6 7 8 9
3 4 5 6 7 8 9 10
4 5 6 7 8 9 10 11
5 6 7 8 9 10 11 12
6 7 8 9 10 11 12 13
7 8 9 10 11 12 13 14
8 9 10 11 12 13 14 15
Matrix B:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
...Performing Multiplication with Strassen's...
Matrix A*B:
316 424 484 528 460 440 372 288
300 398 452 426 412 366 308 154
268 360 414 446 348 312 246 134
252 254 382 424 300 126 182 112
156 232 260 272 404 352 252 136
140 150 228 34 356 334 188 138
108 168 70 54 292 224 246 118
92 -122 38 24 244 222 182 104
:_) 抱歉。 这部分有一个轻微的数学错误:
int **c11=subtract(add(p5,p4,n/2),add(p2,p6,n/2),n/2);
int **c12=add(p1,p2,n/2);
int **c21=add(p3,p4,n/2);
int **c22=subtract(add(p5,p1,n/2),subtract(p3,p7,n/2),n/2);
用
替换c11和c22int **c11=subtract(add(add(p5,p4,n/2),p6,n/2),p2,n/2);
...
int **c22=subtract(subtract(add(p5,p1,n/2),p3,n/2),p7,n/2);
更正数学错误。