Fortran 的 MKL dgemm 给出零结果

Fortran's MKL dgemm is giving zero result

在下面的 Fortran 程序中,我使用英特尔的 MKL 库使用 dgemm 执行矩阵乘法。最初,我使用了 matmul 子程序并得到了正确的结果。当我在下面的循环中将 matmul 翻译成 dgemm 时,我得到了所有零向量而不是正确的输出。感谢您的帮助。

program spectral_norm    
implicit none
!
integer, parameter :: n = 5500, dp = kind(0.0d0)
real(dp), allocatable :: A(:, :), u(:), v(:), Au(:), Av(:)
integer :: i, j

allocate(u(n), v(n), A(n, n), Au(n), Av(n))

do j = 1, n
    do i = 1, n
        A(i, j) = Ac(i, j)
    end do
end do

u = 1
do i = 1, 10
    call dgemm('N','N', n, 1, n, 1.0, A,  n, u, n, 0.0, Au, n) 
    call dgemm('N','N', n, 1, n, 1.0, Au, n, A, n, 0.0, v,  n) 
    call dgemm('N','N', n, 1, n, 1.0, A,  n, v, n, 0.0, Av, n) 
    call dgemm('N','N', n, 1, n, 1.0, Av, n, A, n, 0.0, u,  n) 
    !v = matmul(matmul(A, u), A)
    !u = matmul(matmul(A, v), A)
end do

write(*, "(f0.9)") sqrt(dot_product(u, v) / dot_product(v, v))

contains

pure real(dp) function Ac(i, j) result(r)
integer, intent(in) :: i, j
r = 1._dp / ((i+j-2) * (i+j-1)/2 + i)
end function

end program spectral_norm

这给出 NaN,而 matmul 的正确输出是 1.274224153

好的,谢谢大家的建议。我想我找到了错误的根源。有两种情况乘法顺序颠倒了,应该是A * AuA * Av。这是因为 A 的顺序是 n x n,而 AuAv 的顺序都是 n x 1。因此,由于尺寸不匹配,我们无法乘以 Au * AAv * A。我在下面发布了更正后的版本。

program spectral_norm    
implicit none
!
integer, parameter :: n = 5500, dp = kind(0.d0)
real(dp), allocatable :: A(:,:), u(:), v(:), Au(:), Av(:)
integer :: i, j

allocate(u(n), v(n), A(n, n), Au(n), Av(n))

do j = 1, n
    do i = 1, n
        A(i, j) = Ac(i, j)
    end do
end do

u = 1
do i = 1, 10
    call dgemm('N', 'N', n, 1, n, 1._dp, A, n, u,  n, 0._dp, Au, n)
    call dgemm('T', 'N', n, 1, n, 1._dp, A, n, Au, n, 0._dp, v,  n)
    call dgemm('N', 'N', n, 1, n, 1._dp, A, n, v,  n, 0._dp, Av, n)
    call dgemm('T', 'N', n, 1, n, 1._dp, A, n, Av, n, 0._dp, u,  n)
end do

write(*, "(f0.9)") sqrt(dot_product(u, v) / dot_product(v, v))

contains

pure real(dp) function Ac(i, j) result(r)
    integer, intent(in) :: i, j
    r = 1._dp / ((i+j-2) * (i+j-1)/2 + i)
end function

end program spectral_norm

这给出了正确的结果:

1.274224153
 Elapsed time   0.5150000     seconds