递归行优先的矩阵乘法

Matrix multiplication of row-major recursively

我正在编写自己的矩阵模块以进行娱乐和练习(时间和 space 复杂性无关紧要)。 现在我想实现矩阵乘法,我正在努力实现它。这可能是我使用 Haskell 的原因,而且我对它没有太多经验。 这是我的数据类型:

data Matrix a =
M {
  rows::Int,
  cols::Int,
  values::[a]
}

在数组中存储这样的 3x2 矩阵:

1 2
3 4
5 6
= [1,2,3,4,5,6]

我有一个可以使用的转置函数

transpose::(Matrix a)->(Matrix a)
transpose (M rows cols values) = M cols rows (aux values 0 0 [])
  where
   aux::[a]->Int->Int->[a]->[a]
   aux values row col transposed 
     | cols > col =
       if rows > row then 
         aux values (row+1) col (transposed ++ [valueAtIndex (M rows cols values) (row,col)])
       else aux values 0 (col+1) transposed
     | otherwise = transposed

索引数组中的元素我正在使用这个函数

valueAtIndex::(Matrix a)->(Int, Int)->a
valueAtIndex (M rows cols values) (row, col) 
  | rows <= row || cols <= col = error "indices too large for given Matrix"
  | otherwise = values !! (cols * row + col)

根据我的理解,我必须为 m1: 2x3 和 m2: 3x2 获取这样的元素

m1(0,0)*m2(0,0)+m1(0,1)*m2(0,1)+m1(0,2)*m2(0,2)
m1(0,0)*m2(1,0)+m1(0,1)*m2(1,1)+m1(0,2)*m2(1,2)
m1(1,0)*m2(0,0)+m1(1,1)*m2(0,1)+m1(1,2)*m2(0,2)
m1(1,0)*m2(1,0)+m1(1,1)*m2(1,1)+m1(1,2)*m2(2,2)

现在我需要一个接受两个矩阵的函数,rows m1 == cols m2然后以某种方式递归计算正确的矩阵。

multiplyMatrix::Num a=>(Matrix a)->(Matrix a)->(Matrix a)

首先,我不太相信这样的线性列表是个好主意。 Haskell 中的列表被建模为 链表 。所以这意味着通常访问第 k 个元素,将 运行 in O(k)。因此,对于 m×n-矩阵,这意味着它需要 O(m n) 才能访问最后一个元素。通过使用 2d 链表:一个包含链表的链表,我们将其缩小到 O(m+n),这通常更快。是的,因为您使用了更多 "cons" 数据构造函数,所以会有一些开销,但遍历量通常较低。如果您真的想要快速访问,您应该使用数组、向量等。但是还有其他设计决策需要做出。

所以我建议我们将矩阵建模为:

data Matrix a = M {
  rows :: Int,
  cols :: Int,
  values :: <b>[</b>[a]<b>]</b>
}

现在有了这个数据构造函数,我们可以将转置定义为:

transpose' :: Matrix a -> Matrix a
transpose' (M r c as) = M c r (trans as)
    where trans [] = []
          trans xs = map head xs : trans (map tail xs)

(这里我们假设列表的列表总是矩形的)

现在进行矩阵乘法。如果 AB 是两个矩阵,并且 C = A × B,那么基本上意味着ai,ji[=48=的点积 A 的第 ] 行和 B 的第 j 列。或者iA,以及j ]BTB的转置)。因此我们可以将点积定义为:

dot_prod :: Num a => [a] -> [a] -> a
dot_prod xs ys = sum (zipWith (*) xs ys)

现在只需遍历行和列,并将元素放在正确的列表中即可。喜欢:

mat_mul :: Num a => Matrix a -> Matrix a -> Matrix a
mat_mul (M r ca xss) m2 | ca /= ra = error "Invalid matrix shapes"
                        | otherwise = M r c (matmul xss)
    where (M c rb yss) = transpose m2
          matmul [] = []
          matmul (xs:xss) = generaterow yss xs : matmul xss
          generaterow [] _ = []
          generaterow (ys:yss) xs = dot_prod xs ys : generaterow yss xs