使用带有 numpy 矩阵的 Strassen 算法的输出矩阵不正确
Incorrect output matrix using Strassen's algorithm with numpy matrices
我正在尝试使用 Python 3 和 numpy 矩阵实现 CLRS 中描述的 Strassen 矩阵乘法算法。
问题是输出矩阵 C 作为零矩阵而不是正确的乘积返回。我不确定为什么我的实现不起作用,但怀疑这与每次递归调用时创建 C 矩阵有关。对于我做错了什么以及如何解决它的任何解释,我将不胜感激。
谢谢!
import numpy as np
def strassen(A,B):
n = A.shape[0]
C = np.zeros((n*n), dtype=np.int).reshape(n,n)
if n == 1:
C[0][0] = A[0][0] * B[0][0]
else:
k = int(n/2)
A11,A21,A12,A22 = A[:k,:k], A[k:, :k], A[:k, k:], A[k:, k:]
B11,B21,B12,B22 = B[:k,:k], B[k:, :k], B[:k, k:], B[k:, k:]
C11,C21,C12,C22 = C[:k,:k], C[k:, :k], C[:k, k:], C[k:, k:]
S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10= B11 + B12
P1 = strassen(A11, S1)
P2 = strassen(S2, B22)
P3 = strassen(S3, B11)
P4 = strassen(A22, S4)
P5 = strassen(S5, S6)
P6 = strassen(S7, S8)
P7 = strassen(S9, S10)
C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7
return C
好的,我通过简单地用新值更新切片 C[:k,:k] 而不是创建新变量 C11、C12 ..ect 来让它工作。
因为这样做会创建一个新矩阵,而不是对原始矩阵 C 的引用。
我正在尝试使用 Python 3 和 numpy 矩阵实现 CLRS 中描述的 Strassen 矩阵乘法算法。
问题是输出矩阵 C 作为零矩阵而不是正确的乘积返回。我不确定为什么我的实现不起作用,但怀疑这与每次递归调用时创建 C 矩阵有关。对于我做错了什么以及如何解决它的任何解释,我将不胜感激。
谢谢!
import numpy as np
def strassen(A,B):
n = A.shape[0]
C = np.zeros((n*n), dtype=np.int).reshape(n,n)
if n == 1:
C[0][0] = A[0][0] * B[0][0]
else:
k = int(n/2)
A11,A21,A12,A22 = A[:k,:k], A[k:, :k], A[:k, k:], A[k:, k:]
B11,B21,B12,B22 = B[:k,:k], B[k:, :k], B[:k, k:], B[k:, k:]
C11,C21,C12,C22 = C[:k,:k], C[k:, :k], C[:k, k:], C[k:, k:]
S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10= B11 + B12
P1 = strassen(A11, S1)
P2 = strassen(S2, B22)
P3 = strassen(S3, B11)
P4 = strassen(A22, S4)
P5 = strassen(S5, S6)
P6 = strassen(S7, S8)
P7 = strassen(S9, S10)
C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7
return C
好的,我通过简单地用新值更新切片 C[:k,:k] 而不是创建新变量 C11、C12 ..ect 来让它工作。 因为这样做会创建一个新矩阵,而不是对原始矩阵 C 的引用。