Python:分而治之递归矩阵乘法
Python: Divide and Conquer Recursive Matrix Multiplication
我正在尝试实现分而治之矩阵乘法(8 递归版本不是 Strassen)。我以为我已经弄清楚了,但是它产生了带有太多嵌套列表和错误值的奇怪输出。我怀疑问题是我如何对 8 次递归求和,但我不确定。
def multiMatrix(x,y):
n = len(x)
if n == 1:
return x[0][0] * y[0][0]
else:
a = [[col for col in row[:len(row)/2]] for row in x[:len(x)/2]]
b = [[col for col in row[len(row)/2:]] for row in x[:len(x)/2]]
c = [[col for col in row[:len(row)/2]] for row in x[len(x)/2:]]
d = [[col for col in row[len(row)/2:]] for row in x[len(x)/2:]]
e = [[col for col in row[:len(row)/2]] for row in y[:len(y)/2]]
f = [[col for col in row[len(row)/2:]] for row in y[:len(y)/2]]
g = [[col for col in row[:len(row)/2]] for row in y[len(y)/2:]]
h = [[col for col in row[len(row)/2:]] for row in y[len(y)/2:]]
ae = multiMatrix(a,e)
bg = multiMatrix(b,g)
af = multiMatrix(a,f)
bh = multiMatrix(b,h)
ce = multiMatrix(c,e)
dg = multiMatrix(d,g)
cf = multiMatrix(c,f)
dh = multiMatrix(d,h)
c = [[ae+bg,af+bh],[ce+dg,cf+dh]]
return c
a = [
[1,2,3,4],
[5,6,7,8],
[9,10,11,12],
[13,14,15,16]
]
b = [
[1,2,3,4],
[5,6,7,8],
[9,10,11,12],
[13,14,15,16]
]
print multiMatrix(a,b)
你的怀疑是正确的,你的矩阵仍然是列表,所以添加它们只会得到一个更长的列表。
尝试使用类似这样的东西
def matrix_add(a, b):
return [[ea+eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]
在您的代码中。
加入块:
def join_horiz(a, b):
return [rowa + rowb for rowa, rowb in zip(a,b)]
def join_vert(a, b):
return a+b
最后,为了使它们协同工作,我认为您必须将 1 的特殊情况更改为
return [[x[0][0] * y[0][0]]]
编辑:
我刚刚意识到这只适用于二维的幂。否则你将不得不处理非方矩阵,并且 x
是 1 x 某些东西并且你的特殊情况将不起作用。所以你还必须检查 len(x[0]) (如果 n > 0)。
我正在尝试实现分而治之矩阵乘法(8 递归版本不是 Strassen)。我以为我已经弄清楚了,但是它产生了带有太多嵌套列表和错误值的奇怪输出。我怀疑问题是我如何对 8 次递归求和,但我不确定。
def multiMatrix(x,y):
n = len(x)
if n == 1:
return x[0][0] * y[0][0]
else:
a = [[col for col in row[:len(row)/2]] for row in x[:len(x)/2]]
b = [[col for col in row[len(row)/2:]] for row in x[:len(x)/2]]
c = [[col for col in row[:len(row)/2]] for row in x[len(x)/2:]]
d = [[col for col in row[len(row)/2:]] for row in x[len(x)/2:]]
e = [[col for col in row[:len(row)/2]] for row in y[:len(y)/2]]
f = [[col for col in row[len(row)/2:]] for row in y[:len(y)/2]]
g = [[col for col in row[:len(row)/2]] for row in y[len(y)/2:]]
h = [[col for col in row[len(row)/2:]] for row in y[len(y)/2:]]
ae = multiMatrix(a,e)
bg = multiMatrix(b,g)
af = multiMatrix(a,f)
bh = multiMatrix(b,h)
ce = multiMatrix(c,e)
dg = multiMatrix(d,g)
cf = multiMatrix(c,f)
dh = multiMatrix(d,h)
c = [[ae+bg,af+bh],[ce+dg,cf+dh]]
return c
a = [
[1,2,3,4],
[5,6,7,8],
[9,10,11,12],
[13,14,15,16]
]
b = [
[1,2,3,4],
[5,6,7,8],
[9,10,11,12],
[13,14,15,16]
]
print multiMatrix(a,b)
你的怀疑是正确的,你的矩阵仍然是列表,所以添加它们只会得到一个更长的列表。
尝试使用类似这样的东西
def matrix_add(a, b):
return [[ea+eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]
在您的代码中。
加入块:
def join_horiz(a, b):
return [rowa + rowb for rowa, rowb in zip(a,b)]
def join_vert(a, b):
return a+b
最后,为了使它们协同工作,我认为您必须将 1 的特殊情况更改为
return [[x[0][0] * y[0][0]]]
编辑:
我刚刚意识到这只适用于二维的幂。否则你将不得不处理非方矩阵,并且 x
是 1 x 某些东西并且你的特殊情况将不起作用。所以你还必须检查 len(x[0]) (如果 n > 0)。