有什么简单的方法可以在 Fortran 中对数组的一系列排列进行求和?

Is there any simple way to realize sum over series of permutations of array in Fortran?

我试图在中编写python代码 在 Fortran 中查看我是否可以实现任何加速(-O3 有很大帮助;Ian Bush 在 Transposition of a matrix by multithread in Fortran 中的回答中的方法对我来说似乎太复杂了)。例如,

0.1 * A(l1,l2,l3,l4) + 0.2*A(l1,l2,l4,l3) + 0.3 * A(l1,l3,l2,l4)+...

如果我尝试从

扩展
Program transpose
  
    integer, parameter :: dp = selected_real_kind(15, 307)

    real(dp), dimension(:, :, :, :), allocatable :: a, b
    Integer :: n1, n2, n3, n4, n, m_iter
    Integer :: l1, l2, l3, l4  
    Integer(8) :: start, finish, rate
    real(dp) :: sum_time
    
    Write(*, *) 'n1, n2, n3, n4?'
    Read(*, *) n1, n2, n3, n4

    Allocate( a ( 1:n1, 1:n2, 1:n3, 1:n4 ) )
    Allocate( b ( 1:n1, 1:n2, 1:n3, 1:n4 ) )
    
    call random_init(.true., .false.)
    Call Random_number( a )
    
    m_iter = 100
    b = 0.0_dp 
    Call system_clock( start, rate )
    do n = 1, m_iter  
      do l4 = 1, n4
        do l3 = 1, n3     
          do l2 = 1, n2
            do l1 = 1, n1
              b(l1,l2,l3,l4) = 0.1_dp*a(l1,l2,l3,l4) + 0.2_dp*a(l1,l2,l4,l3)
            end do
          end do                    
        end do
      end do        
    end do
    Call system_clock( finish, rate )
    sum_time =  real( finish - start, dp ) / rate  

    write (*,*) 'all loop', sum_time/m_iter 
    print *, b(1,1,1,1)

  End 

(我试过reshape,比嵌套循环慢)

有什么简单的方法可以包含 A(l1,l3,l2,l4)A(l1,l3,l4,l2) 等吗?我可以使用 Python 生成一个字符串以包含所有字符串 \ 用于更改行。

一个潜在的复杂性是,如果有项 0.0 * A(l4,l3,l2,l1),我想跳过它,从 python 生成一个字符串很复杂。还有更多类似 Fortran 的解决方案吗?

另一个问题是,如果数组 A 在每个索引中具有不同的维度,例如 n1 != n2 != n3 != n4,则某些排列可能会超出范围。在这种情况下,前置因子将为零。例如,如果 n1 = n2 = 10n3 = n4 = 20,它将类似于 0.1 * A(l1,l2,l3,l4) + 0.0 * A(l1,l3,l2,l4)。换句话说,b = 0.1*a + 0.0*reshape(a, (/n1, n2, n3, n4/), order = (/1,3,2,4/) ) ,或者说 0.1*a + 0.0 * P(2,3) a,其中 P 是置换运算符。通过检查置换前因子的绝对值低于某个阈值,求和将能够跳过该置换。

在这种情况下,前置因子将为零。求和应该跳过那种类型的排列。

已编辑:下面是 python 参考实现。我通过变量 gen_random 包含了随机和非随机版本。后者可能更容易检查。

import numpy as np
import time
import itertools as it


ref_list = [0, 1, 2, 3]
p = it.permutations(ref_list)
transpose_list = tuple(p)

n_loop = 2
na = nb = nc = nd = 30


A = np.zeros((na,nb,nc,nd))
gen_random = False
if gen_random == False:
    n = 1
    for la in range(na):
        for lb in range(nb):
            for lc in range(nc):
                for ld in range(nd):
                   A[la,lb,lc,ld] = n
                   n = n + 1          
else:
    A = np.random.random((na,nb,nc,nd))

factor_list = [(i+1)*0.1 for i in range(24)]
time_total = 0.0

for n in range(n_loop):
    sum_A = np.zeros((na,nb,nc,nd))
    start_0 = time.time()
    for m, t in enumerate(transpose_list):
       sum_A = np.add(sum_A, factor_list[m]  * np.transpose(A, transpose_list[m] ), out = sum_A) 
       #sum_A += factor_list[m]  * np.transpose(A, transpose_list[m]) 
    finish_0 = time.time()
    time_total += finish_0 - start_0


print('level 4', time_total/n_loop) 
print('Ref value', A[0,0,0,0], sum_A[0,0,0,0]) 

作为完整性检查,如果 A[0,0,0,0] 不为零,sum_A[0,0,0,0]/A[0,0,0,0] = 30,由 0.1 + 0.2 +... + 2.4 = (0.1+2.4)*2.4/2=30。虽然排列因子可以不同,但​​以上只是一个例子。

这是我认为在 Fortran 中使用的一种方法,它也会跳过前置因子为零的项。我不声称它是最好的,有很多方法可以做到。我也对声明其正确性犹豫不决,你提供的内容使得很难对其进行全面评估。但是当所有尺寸都相同时,它确实通过了健全性测试......你没有办法检查更一般的情况。

主要问题是 Fortran 无法创建您需要的排列,因此我编写了一个小模块,我相信它实现了与 python 相同的排序。这是从 python documentation and the algorithm to implement it from wikipedia 中获取的排序。单元测试强烈建议它完成工作。

一旦你有了它,就可以很容易地遍历每个排列,依次跳过那些权重为零的排列,因为前因子为零,或者形状不兼容。因此,除了上面的所有警告,这里是我的努力以及编译器版本和一些测试,一些检查了数组边界,一些没有检查。

请注意,即使这是正确的,我当然不会声称它有多优化 - 内存访问模式非常 non-trivial,并且优化它需要比我愿意给出的更多的思考现在,尽管我怀疑需要缓存阻塞,就像您引用的 matrix transposition question 中那样。

Module permutations_module

  ! Little module to handle permutations of an arbitrary sized list of integer 1, 2, 3, .... n
  
  Implicit None

  Type, Public :: permutation
     Private
     Integer, Dimension( : ), Allocatable, Private :: state
   Contains
     Procedure, Public :: init
     Procedure, Public :: get
     Procedure, Public :: next
  End type permutation

  Private

Contains

  Subroutine init( p, n )

    ! Initalise a permutation

    Class( permutation ), Intent(   Out ) :: p
    Integer             , Intent( In    ) :: n

    Integer :: i
    
    Allocate( p%state( 1:n ) )

    p%state = [ ( i, i = 1, Size( p%state ) ) ]

  End Subroutine init

  Pure Function get( p ) Result( a )

    ! Get the current permutation

    Class( permutation ), Intent( In ) :: p

    Integer, Dimension( : ), Allocatable :: a

    a = p%state

  End Function get

  Function next( p ) Result( finished )

    ! Move onto the next permutation, returning .True. if there are no more permutations in the list
    
    Logical :: finished
  
    Class( permutation ), Intent( InOut ) :: p

    Integer :: k, l
    Integer :: tmp

    finished = .False.

    Do k = Size( p%state ) - 1, 1, -1
       If( p%state( k ) < p%state( k + 1 ) ) Exit
    End Do
    finished = k == 0

    If( .Not. finished ) Then
       
       Do l = Size( p%state ), k + 1, -1
          If( p%state( k ) < p%state( l ) ) Exit
       End Do

       tmp = p%state( k )
       p%state( k ) = p%state( l )
       p%state( l ) = tmp

       p%state( k + 1: ) = p%state( Size( p%state ):k + 1: - 1 )

    End If
    
  End Function next

End Module permutations_module

Program testit

  Use, Intrinsic :: iso_fortran_env, Only : wp => real64

  Use permutations_module, Only : permutation

  Implicit None

  Integer, Parameter :: n_iter = 100

  Type( permutation ) :: p

  Integer :: i
  
  Real( wp ), Dimension( :, :, :, : ), Allocatable :: a
  Real( wp ), Dimension( :, :, :, : ), Allocatable :: b
  
  Real( wp ), Dimension( 1:Product( [ ( i, i = 1, Size( Shape( a ) ) ) ] ) ) :: c

  Integer, Dimension( 1:Size( Shape( a ) ) ) :: this_permutation
  Integer, Dimension( 1:Size( Shape( a ) ) ) :: sizes
  Integer, Dimension( 1:Size( Shape( a ) ) ) :: permuted_sizes
  Integer, Dimension( 1:Size( Shape( a ) ) ) :: indices
  Integer, Dimension( 1:Size( Shape( a ) ) ) :: permuted_indices
  
  Integer :: n1, n2, n3, n4
  Integer :: l1, l2, l3, l4
  Integer :: iter
  Integer :: start, finish, rate
  
  Logical :: finished

  c = [ ( i * 0.1_wp, i = 1, Size( c ) ) ]

  Write( *, * ) 'n1, n2, n3, n4?'
  Read ( *, * ) n1, n2, n3, n4

  Allocate( a ( 1:n1, 1:n2, 1:n3, 1:n4 ) )
  Allocate( b, Mold = a ) 

  Call Random_init( .true., .false. )
  Call Random_number( a )
  ! Make sure a( 1, 1, 1, 1 ) is not zero for the sanity check
  a( 1, 1, 1, 1 ) = a( 1, 1, 1, 1 ) + 0.1_wp

  Call system_clock( start, rate )
  sizes = Shape( a )
  b = 0.0_wp

  iter_loop: Do iter = 1, n_iter

     Call p%init( Size( Shape( a ) ) )
     i = 0
     finished = .False.
     permutation_loop: Do While( .Not. finished )
        i = i + 1

        ! Get the next permutation
        finished = p%next()

        ! Only do it if it has any weight
        If( Abs( c( i ) ) > Epsilon( c( i ) ) ) Then

           ! Get the current permutation
           this_permutation = p%get()

           ! Check the shapes are compatible
           permuted_sizes = sizes( this_permutation )

           If( All( permuted_sizes == sizes ) ) Then

              ! Add in the current permutation
              Do l4 = 1, n4
                 Do l3 = 1, n3     
                    Do l2 = 1, n2
                       Do l1 = 1, n1
                          indices = [ l1, l2, l3, l4 ]
                          permuted_indices = indices( this_permutation )
                          b( indices( 1 ), indices( 2 ), indices( 3 ), indices( 4 ) ) = &
                               b(indices( 1 ), indices( 2 ), indices( 3 ), indices( 4 )  ) + &
                               c( i ) * a( permuted_indices( 1 ), permuted_indices( 2 ), &
                               permuted_indices( 3 ), permuted_indices( 4 ) )
                       End Do
                    End Do
                 End Do
              End Do
              
           End If
              
        End If
     End Do permutation_loop
     
  End Do iter_loop
  
  Call system_clock( finish, rate )
  Write( *, * ) 'time per iteration = ', Real( finish - start ) / Real( rate ) / Real( n_iter )
  Write( *, * ) 'Sanity ', b( 1, 1, 1, 1 ) / a( 1, 1, 1, 1 ) / n_iter
     
End Program testit
ijb@ijb-Latitude-5410:~/work/stack$ gfortran --version
GNU Fortran (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

ijb@ijb-Latitude-5410:~/work/stack$ gfortran  -fcheck=all -Wall -Wextra -O3 -g -std=f2018 perm_mod_single_thread.f90
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
 n1, n2, n3, n4?
10 10 10 10
 time per iteration =    1.47000002E-03
 Sanity    30.000000000000259     
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
 n1, n2, n3, n4?
11 10 9 12
 time per iteration =    0.00000000    
 Sanity    0.0000000000000000     
ijb@ijb-Latitude-5410:~/work/stack$ gfortran -Wall -Wextra -O3 -g -std=f2018 perm_mod_single_thread.f90
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
 n1, n2, n3, n4?
30 30 30 30
 time per iteration =    6.56599998E-02
 Sanity    29.999999999999844     
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
 n1, n2, n3, n4?
60 60 60 60
 time per iteration =    2.46800995    
 Sanity    30.000000000000036     
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
 n1, n2, n3, n4?
10 20 30 40
 time per iteration =    2.00000013E-05
 Sanity    0.0000000000000000     
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
 n1, n2, n3, n4?
30 30 15 15
 time per iteration =    1.50999997E-03
 Sanity    1.4000000000000050     
ijb@ijb-Latitude-5410:~/work/stack$