有什么简单的方法可以在 Fortran 中对数组的一系列排列进行求和?
Is there any simple way to realize sum over series of permutations of array in Fortran?
我试图用 Fortran 重写一段 Python 代码,看看是否可以实现任何加速(-O3
有很大帮助;Ian Bush 在 Transposition of a matrix by multithread in Fortran 中的回答中的方法对我来说似乎太复杂了)。例如,需要计算如下形式的和:
0.1 * A(l1,l2,l3,l4) + 0.2*A(l1,l2,l4,l3) + 0.3 * A(l1,l3,l2,l4)+...
如果我尝试从
扩展
Program transpose
  ! Times the weighted sum  b = 0.1 * a + 0.2 * P(3,4) a  for a rank-4 array a,
  ! where P(3,4) swaps the last two indices.  The sizes n1..n4 are read from
  ! standard input, the sum is evaluated m_iter times, and the average
  ! wall-clock time per evaluation plus one reference element are printed.
  Implicit None
  integer, parameter :: dp = selected_real_kind(15, 307)
  real(dp), dimension(:, :, :, :), allocatable :: a, b
  Integer :: n1, n2, n3, n4, n, m_iter
  Integer :: l1, l2, l3, l4
  Integer(8) :: start, finish, rate
  real(dp) :: sum_time
  Write(*, *) 'n1, n2, n3, n4?'
  Read(*, *) n1, n2, n3, n4
  Allocate( a ( 1:n1, 1:n2, 1:n3, 1:n4 ) )
  Allocate( b ( 1:n1, 1:n2, 1:n3, 1:n4 ) )
  call random_init(.true., .false.)
  Call Random_number( a )
  m_iter = 100
  b = 0.0_dp
  Call system_clock( start, rate )
  do n = 1, m_iter
    if( n3 == n4 ) then
      ! Innermost loop runs over the first index so that b and the
      ! un-permuted term of a are accessed contiguously (column-major).
      do l4 = 1, n4
        do l3 = 1, n3
          do l2 = 1, n2
            do l1 = 1, n1
              b(l1,l2,l3,l4) = 0.1_dp*a(l1,l2,l3,l4) + 0.2_dp*a(l1,l2,l4,l3)
            end do
          end do
        end do
      end do
    else
      ! Swapping indices 3 and 4 would index a outside its bounds when
      ! n3 /= n4; that permutation carries zero weight and is skipped.
      b = 0.1_dp*a
    end if
  end do
  Call system_clock( finish, rate )
  sum_time = real( finish - start, dp ) / real( rate, dp )
  write (*,*) 'all loop', sum_time/m_iter
  print *, b(1,1,1,1)
End Program transpose
(我试过reshape
,比嵌套循环慢)
有什么简单的方法可以包含 A(l1,l3,l2,l4)
、A(l1,l3,l4,l2)
等吗?我可以使用 Python
生成一个字符串以包含所有字符串 \
用于更改行。
一个潜在的复杂性是,如果有项 0.0 * A(l4,l3,l2,l1),我想跳过它,从 python 生成一个字符串很复杂。还有更多类似 Fortran 的解决方案吗?
另一个问题是,如果数组 A
在每个索引中具有不同的维度,例如 n1 != n2 != n3 != n4
,则某些排列可能会超出范围。在这种情况下,前置因子将为零。例如,如果 n1 = n2 = 10
、n3 = n4 = 20
,它将类似于 0.1 * A(l1,l2,l3,l4) + 0.0 * A(l1,l3,l2,l4)
。换句话说,b = 0.1*a + 0.0*reshape(a, (/n1, n2, n3, n4/), order = (/1,3,2,4/) )
,或者说 0.1*a + 0.0 * P(2,3) a
,其中 P
是置换运算符。通过检查置换前因子的绝对值低于某个阈值,求和将能够跳过该置换。
在这种情况下,前置因子将为零。求和应该跳过那种类型的排列。
已编辑:下面是 python 参考实现。我通过变量 gen_random
包含了随机和非随机版本。后者可能更容易检查。
import numpy as np
import time
import itertools as it
# Reference implementation: weighted sum over all 24 axis permutations of a
# rank-4 array.  Permutation number m (in itertools' lexicographic order,
# starting with the identity) is weighted by (m + 1) * 0.1.
ref_list = [0, 1, 2, 3]
transpose_list = tuple(it.permutations(ref_list))

n_loop = 2
na = nb = nc = nd = 30

gen_random = False
if not gen_random:
    # Deterministic fill: 1, 2, 3, ... in C order (last index fastest),
    # i.e. the same values the explicit four-level loop would assign.
    A = np.arange(1, na * nb * nc * nd + 1, dtype=float).reshape(na, nb, nc, nd)
else:
    A = np.random.random((na, nb, nc, nd))

factor_list = [(i + 1) * 0.1 for i in range(len(transpose_list))]

time_total = 0.0
for n in range(n_loop):
    sum_A = np.zeros((na, nb, nc, nd))
    start_0 = time.time()
    for factor, axes in zip(factor_list, transpose_list):
        # np.transpose returns a view; the in-place add accumulates into sum_A.
        sum_A += factor * np.transpose(A, axes)
    finish_0 = time.time()
    time_total += finish_0 - start_0

print('level 4', time_total/n_loop)
print('Ref value', A[0,0,0,0], sum_A[0,0,0,0])
作为完整性检查,如果 A[0,0,0,0]
不为零,sum_A[0,0,0,0]/A[0,0,0,0] = 30
,由 0.1 + 0.2 +... + 2.4 = (0.1+2.4)*24/2=30
。虽然排列因子可以不同,但以上只是一个例子。
这是我认为在 Fortran 中使用的一种方法,它也会跳过前置因子为零的项。我不声称它是最好的,有很多方法可以做到。我也对声明其正确性犹豫不决,你提供的内容使得很难对其进行全面评估。但是当所有尺寸都相同时,它确实通过了健全性测试......你没有办法检查更一般的情况。
主要问题是 Fortran 没有内置的方法来生成您需要的排列,因此我编写了一个小模块,我相信它实现了与 Python 相同的排序。排序方式取自 Python 文档(python documentation),实现它的算法取自维基百科(wikipedia)。单元测试强烈表明它能完成这项工作。
一旦你有了它,就可以很容易地遍历每个排列,依次跳过那些权重为零的排列,因为前因子为零,或者形状不兼容。因此,除了上面的所有警告,这里是我的努力以及编译器版本和一些测试,一些检查了数组边界,一些没有检查。
请注意,即使这是正确的,我也绝不会声称它有多优化——内存访问模式非常不平凡(non-trivial),而优化它所需的思考比我现在愿意投入的更多,不过我怀疑需要缓存分块(cache blocking),就像您引用的矩阵转置问题(matrix transposition question)中那样。
Module permutations_module

  ! Little module to handle permutations of an arbitrary sized list of integers 1, 2, 3, .... n
  !
  ! Typical use:
  !    Call p%init( n )        ! start at the identity permutation 1, 2, ..., n
  !    perm     = p%get()      ! read the current permutation
  !    finished = p%next()     ! step to the next permutation in lexicographic order

  Implicit None

  Type, Public :: permutation
     Private
     ! The current permutation; unallocated until init has been called
     Integer, Dimension( : ), Allocatable, Private :: state
   Contains
     Procedure, Public :: init
     Procedure, Public :: get
     Procedure, Public :: next
  End type permutation

  Private

Contains

  Subroutine init( p, n )

    ! Initialise a permutation to the identity 1, 2, ..., n.
    ! Because p is Intent( Out ) any previous state is discarded on entry.

    Class( permutation ), Intent( Out ) :: p
    Integer             , Intent( In  ) :: n

    Integer :: i

    Allocate( p%state( 1:n ) )
    p%state = [ ( i, i = 1, Size( p%state ) ) ]

  End Subroutine init

  Pure Function get( p ) Result( a )

    ! Get a copy of the current permutation.
    ! An uninitialised permutation yields a zero sized array.

    Class( permutation ), Intent( In ) :: p

    Integer, Dimension( : ), Allocatable :: a

    If( Allocated( p%state ) ) Then
       a = p%state
    Else
       Allocate( a( 1:0 ) )
    End If

  End Function get

  Function next( p ) Result( finished )

    ! Move onto the next permutation in lexicographic order, returning .True.
    ! if there are no more permutations in the list. In that case the state
    ! is left unchanged, i.e. it still holds the last permutation.

    Logical :: finished

    Class( permutation ), Intent( InOut ) :: p

    Integer :: k, l
    Integer :: tmp

    ! Nothing to step through if init has never been called
    finished = .True.
    If( .Not. Allocated( p%state ) ) Return

    ! Find the right-most position k at which the sequence ascends
    Do k = Size( p%state ) - 1, 1, -1
       If( p%state( k ) < p%state( k + 1 ) ) Exit
    End Do

    ! On normal loop completion k is 0 for a non-empty list, but for an empty
    ! list the loop never starts and k stays at -1; both mean "no successor"
    finished = k < 1

    If( .Not. finished ) Then

       ! Find the right-most element larger than the one at k ...
       Do l = Size( p%state ), k + 1, -1
          If( p%state( k ) < p%state( l ) ) Exit
       End Do

       ! ... swap the two ...
       tmp = p%state( k )
       p%state( k ) = p%state( l )
       p%state( l ) = tmp

       ! ... and reverse the tail so that it is in ascending order again
       p%state( k + 1: ) = p%state( Size( p%state ):k + 1: - 1 )

    End If

  End Function next

End Module permutations_module
Program testit

  ! Accumulates b = Sum_i c( i ) * P_i a over all permutations P_i of the four
  ! indices of a, where permutation number i (lexicographic order, identity
  ! first) has weight c( i ) = 0.1 * i. Permutations with negligible weight,
  ! or which would change the shape of a, are skipped. The sum is repeated
  ! n_iter times for timing purposes.

  Use, Intrinsic :: iso_fortran_env, Only : wp => real64, int64

  Use permutations_module, Only : permutation

  Implicit None

  Integer, Parameter :: n_iter = 100

  ! Rank of the arrays; the explicit loop nest below is written for rank 4
  Integer, Parameter :: n_dim = 4

  Type( permutation ) :: p

  Integer :: i

  ! Number of permutations = n_dim!
  Integer, Parameter :: n_perm = Product( [ ( i, i = 1, n_dim ) ] )

  Real( wp ), Dimension( :, :, :, : ), Allocatable :: a
  Real( wp ), Dimension( :, :, :, : ), Allocatable :: b

  Real( wp ), Dimension( 1:n_perm ) :: c

  Integer, Dimension( 1:n_dim ) :: this_permutation
  Integer, Dimension( 1:n_dim ) :: sizes
  Integer, Dimension( 1:n_dim ) :: permuted_sizes
  Integer, Dimension( 1:n_dim ) :: indices
  Integer, Dimension( 1:n_dim ) :: permuted_indices

  Integer :: n1, n2, n3, n4
  Integer :: l1, l2, l3, l4
  Integer :: iter

  ! 64 bit counters give a finer clock resolution than default integers
  Integer( int64 ) :: start, finish, rate

  Logical :: finished

  c = [ ( i * 0.1_wp, i = 1, Size( c ) ) ]

  Write( *, * ) 'n1, n2, n3, n4?'
  Read ( *, * ) n1, n2, n3, n4

  Allocate( a ( 1:n1, 1:n2, 1:n3, 1:n4 ) )
  Allocate( b, Mold = a )

  Call Random_init( .true., .false. )
  Call Random_number( a )
  ! Make sure a( 1, 1, 1, 1 ) is not zero for the sanity check
  a( 1, 1, 1, 1 ) = a( 1, 1, 1, 1 ) + 0.1_wp

  Call system_clock( start, rate )

  sizes = Shape( a )

  ! b is accumulated over all n_iter repeats; the sanity check divides by n_iter
  b = 0.0_wp

  iter_loop: Do iter = 1, n_iter

     ! Start again at the identity permutation, which pairs with c( 1 )
     Call p%init( n_dim )
     i = 0
     finished = .False.

     permutation_loop: Do While( .Not. finished )

        i = i + 1

        ! Only do it if it has any weight
        If( Abs( c( i ) ) > Epsilon( c( i ) ) ) Then

           ! Get the current permutation
           this_permutation = p%get()

           ! Check the shapes are compatible
           permuted_sizes = sizes( this_permutation )
           If( All( permuted_sizes == sizes ) ) Then

              ! Add in the current permutation
              Do l4 = 1, n4
                 Do l3 = 1, n3
                    Do l2 = 1, n2
                       Do l1 = 1, n1
                          indices = [ l1, l2, l3, l4 ]
                          permuted_indices = indices( this_permutation )
                          b( l1, l2, l3, l4 ) = b( l1, l2, l3, l4 ) + &
                               c( i ) * a( permuted_indices( 1 ), permuted_indices( 2 ), &
                                           permuted_indices( 3 ), permuted_indices( 4 ) )
                       End Do
                    End Do
                 End Do
              End Do

           End If

        End If

        ! Advance only after the current permutation has been used. Stepping
        ! first would skip the identity and apply the last permutation twice.
        finished = p%next()

     End Do permutation_loop

  End Do iter_loop

  Call system_clock( finish, rate )

  Write( *, * ) 'time per iteration = ', &
       Real( finish - start, wp ) / Real( rate, wp ) / Real( n_iter, wp )
  Write( *, * ) 'Sanity ', b( 1, 1, 1, 1 ) / a( 1, 1, 1, 1 ) / n_iter

End Program testit
ijb@ijb-Latitude-5410:~/work/stack$ gfortran --version
GNU Fortran (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
ijb@ijb-Latitude-5410:~/work/stack$ gfortran -fcheck=all -Wall -Wextra -O3 -g -std=f2018 perm_mod_single_thread.f90
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
10 10 10 10
time per iteration = 1.47000002E-03
Sanity 30.000000000000259
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
11 10 9 12
time per iteration = 0.00000000
Sanity 0.0000000000000000
ijb@ijb-Latitude-5410:~/work/stack$ gfortran -Wall -Wextra -O3 -g -std=f2018 perm_mod_single_thread.f90
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
30 30 30 30
time per iteration = 6.56599998E-02
Sanity 29.999999999999844
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
60 60 60 60
time per iteration = 2.46800995
Sanity 30.000000000000036
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
10 20 30 40
time per iteration = 2.00000013E-05
Sanity 0.0000000000000000
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
30 30 15 15
time per iteration = 1.50999997E-03
Sanity 1.4000000000000050
ijb@ijb-Latitude-5410:~/work/stack$
我试图用 Fortran 重写一段 Python 代码,看看是否可以实现任何加速(-O3
有很大帮助;Ian Bush 在 Transposition of a matrix by multithread in Fortran 中的回答中的方法对我来说似乎太复杂了)。例如,需要计算如下形式的和:
0.1 * A(l1,l2,l3,l4) + 0.2*A(l1,l2,l4,l3) + 0.3 * A(l1,l3,l2,l4)+...
如果我尝试从
扩展
Program transpose
! Times the weighted sum  b = 0.1 * a + 0.2 * P(3,4) a  for a rank-4 array a,
! where P(3,4) swaps the last two indices.  The sizes n1..n4 are read from
! standard input, the sum is evaluated m_iter times, and the average
! wall-clock time per evaluation plus one reference element are printed.
Implicit None
integer, parameter :: dp = selected_real_kind(15, 307)
real(dp), dimension(:, :, :, :), allocatable :: a, b
Integer :: n1, n2, n3, n4, n, m_iter
Integer :: l1, l2, l3, l4
Integer(8) :: start, finish, rate
real(dp) :: sum_time
Write(*, *) 'n1, n2, n3, n4?'
Read(*, *) n1, n2, n3, n4
Allocate( a ( 1:n1, 1:n2, 1:n3, 1:n4 ) )
Allocate( b ( 1:n1, 1:n2, 1:n3, 1:n4 ) )
call random_init(.true., .false.)
Call Random_number( a )
m_iter = 100
b = 0.0_dp
Call system_clock( start, rate )
do n = 1, m_iter
  if( n3 == n4 ) then
    ! Innermost loop runs over the first index so that b and the
    ! un-permuted term of a are accessed contiguously (column-major).
    do l4 = 1, n4
      do l3 = 1, n3
        do l2 = 1, n2
          do l1 = 1, n1
            b(l1,l2,l3,l4) = 0.1_dp*a(l1,l2,l3,l4) + 0.2_dp*a(l1,l2,l4,l3)
          end do
        end do
      end do
    end do
  else
    ! Swapping indices 3 and 4 would index a outside its bounds when
    ! n3 /= n4; that permutation carries zero weight and is skipped.
    b = 0.1_dp*a
  end if
end do
Call system_clock( finish, rate )
sum_time = real( finish - start, dp ) / real( rate, dp )
write (*,*) 'all loop', sum_time/m_iter
print *, b(1,1,1,1)
End Program
(我试过reshape
,比嵌套循环慢)
有什么简单的方法可以包含 A(l1,l3,l2,l4)
、A(l1,l3,l4,l2)
等吗?我可以使用 Python
生成一个字符串以包含所有字符串 \
用于更改行。
一个潜在的复杂性是,如果有项 0.0 * A(l4,l3,l2,l1),我想跳过它,从 python 生成一个字符串很复杂。还有更多类似 Fortran 的解决方案吗?
另一个问题是,如果数组 A
在每个索引中具有不同的维度,例如 n1 != n2 != n3 != n4
,则某些排列可能会超出范围。在这种情况下,前置因子将为零。例如,如果 n1 = n2 = 10
、n3 = n4 = 20
,它将类似于 0.1 * A(l1,l2,l3,l4) + 0.0 * A(l1,l3,l2,l4)
。换句话说,b = 0.1*a + 0.0*reshape(a, (/n1, n2, n3, n4/), order = (/1,3,2,4/) )
,或者说 0.1*a + 0.0 * P(2,3) a
,其中 P
是置换运算符。通过检查置换前因子的绝对值低于某个阈值,求和将能够跳过该置换。
在这种情况下,前置因子将为零。求和应该跳过那种类型的排列。
已编辑:下面是 python 参考实现。我通过变量 gen_random
包含了随机和非随机版本。后者可能更容易检查。
import numpy as np
import time
import itertools as it
# Reference implementation: weighted sum over all 24 axis permutations of a
# rank-4 array.  Permutation number m (in itertools' lexicographic order,
# starting with the identity) is weighted by (m + 1) * 0.1.
ref_list = [0, 1, 2, 3]
transpose_list = tuple(it.permutations(ref_list))

n_loop = 2
na = nb = nc = nd = 30

gen_random = False
if not gen_random:
    # Deterministic fill: 1, 2, 3, ... in C order (last index fastest),
    # i.e. the same values the explicit four-level loop would assign.
    A = np.arange(1, na * nb * nc * nd + 1, dtype=float).reshape(na, nb, nc, nd)
else:
    A = np.random.random((na, nb, nc, nd))

factor_list = [(i + 1) * 0.1 for i in range(len(transpose_list))]

time_total = 0.0
for n in range(n_loop):
    sum_A = np.zeros((na, nb, nc, nd))
    start_0 = time.time()
    for factor, axes in zip(factor_list, transpose_list):
        # np.transpose returns a view; the in-place add accumulates into sum_A.
        sum_A += factor * np.transpose(A, axes)
    finish_0 = time.time()
    time_total += finish_0 - start_0

print('level 4', time_total/n_loop)
print('Ref value', A[0,0,0,0], sum_A[0,0,0,0])
作为完整性检查,如果 A[0,0,0,0]
不为零,sum_A[0,0,0,0]/A[0,0,0,0] = 30
,由 0.1 + 0.2 +... + 2.4 = (0.1+2.4)*24/2=30
。虽然排列因子可以不同,但以上只是一个例子。
这是我认为在 Fortran 中使用的一种方法,它也会跳过前置因子为零的项。我不声称它是最好的,有很多方法可以做到。我也对声明其正确性犹豫不决,你提供的内容使得很难对其进行全面评估。但是当所有尺寸都相同时,它确实通过了健全性测试......你没有办法检查更一般的情况。
主要问题是 Fortran 没有内置的方法来生成您需要的排列,因此我编写了一个小模块,我相信它实现了与 Python 相同的排序。排序方式取自 Python 文档(python documentation),实现它的算法取自维基百科(wikipedia)。单元测试强烈表明它能完成这项工作。
一旦你有了它,就可以很容易地遍历每个排列,依次跳过那些权重为零的排列,因为前因子为零,或者形状不兼容。因此,除了上面的所有警告,这里是我的努力以及编译器版本和一些测试,一些检查了数组边界,一些没有检查。
请注意,即使这是正确的,我也绝不会声称它有多优化——内存访问模式非常不平凡(non-trivial),而优化它所需的思考比我现在愿意投入的更多,不过我怀疑需要缓存分块(cache blocking),就像您引用的矩阵转置问题(matrix transposition question)中那样。
Module permutations_module

  ! Little module to handle permutations of an arbitrary sized list of integers 1, 2, 3, .... n
  !
  ! Typical use:
  !    Call p%init( n )        ! start at the identity permutation 1, 2, ..., n
  !    perm     = p%get()      ! read the current permutation
  !    finished = p%next()     ! step to the next permutation in lexicographic order

  Implicit None

  Type, Public :: permutation
     Private
     ! The current permutation; unallocated until init has been called
     Integer, Dimension( : ), Allocatable, Private :: state
   Contains
     Procedure, Public :: init
     Procedure, Public :: get
     Procedure, Public :: next
  End type permutation

  Private

Contains

  Subroutine init( p, n )

    ! Initialise a permutation to the identity 1, 2, ..., n.
    ! Because p is Intent( Out ) any previous state is discarded on entry.

    Class( permutation ), Intent( Out ) :: p
    Integer             , Intent( In  ) :: n

    Integer :: i

    Allocate( p%state( 1:n ) )
    p%state = [ ( i, i = 1, Size( p%state ) ) ]

  End Subroutine init

  Pure Function get( p ) Result( a )

    ! Get a copy of the current permutation.
    ! An uninitialised permutation yields a zero sized array.

    Class( permutation ), Intent( In ) :: p

    Integer, Dimension( : ), Allocatable :: a

    If( Allocated( p%state ) ) Then
       a = p%state
    Else
       Allocate( a( 1:0 ) )
    End If

  End Function get

  Function next( p ) Result( finished )

    ! Move onto the next permutation in lexicographic order, returning .True.
    ! if there are no more permutations in the list. In that case the state
    ! is left unchanged, i.e. it still holds the last permutation.

    Logical :: finished

    Class( permutation ), Intent( InOut ) :: p

    Integer :: k, l
    Integer :: tmp

    ! Nothing to step through if init has never been called
    finished = .True.
    If( .Not. Allocated( p%state ) ) Return

    ! Find the right-most position k at which the sequence ascends
    Do k = Size( p%state ) - 1, 1, -1
       If( p%state( k ) < p%state( k + 1 ) ) Exit
    End Do

    ! On normal loop completion k is 0 for a non-empty list, but for an empty
    ! list the loop never starts and k stays at -1; both mean "no successor"
    finished = k < 1

    If( .Not. finished ) Then

       ! Find the right-most element larger than the one at k ...
       Do l = Size( p%state ), k + 1, -1
          If( p%state( k ) < p%state( l ) ) Exit
       End Do

       ! ... swap the two ...
       tmp = p%state( k )
       p%state( k ) = p%state( l )
       p%state( l ) = tmp

       ! ... and reverse the tail so that it is in ascending order again
       p%state( k + 1: ) = p%state( Size( p%state ):k + 1: - 1 )

    End If

  End Function next

End Module permutations_module
Program testit

  ! Accumulates b = Sum_i c( i ) * P_i a over all permutations P_i of the four
  ! indices of a, where permutation number i (lexicographic order, identity
  ! first) has weight c( i ) = 0.1 * i. Permutations with negligible weight,
  ! or which would change the shape of a, are skipped. The sum is repeated
  ! n_iter times for timing purposes.

  Use, Intrinsic :: iso_fortran_env, Only : wp => real64, int64

  Use permutations_module, Only : permutation

  Implicit None

  Integer, Parameter :: n_iter = 100

  ! Rank of the arrays; the explicit loop nest below is written for rank 4
  Integer, Parameter :: n_dim = 4

  Type( permutation ) :: p

  Integer :: i

  ! Number of permutations = n_dim!
  Integer, Parameter :: n_perm = Product( [ ( i, i = 1, n_dim ) ] )

  Real( wp ), Dimension( :, :, :, : ), Allocatable :: a
  Real( wp ), Dimension( :, :, :, : ), Allocatable :: b

  Real( wp ), Dimension( 1:n_perm ) :: c

  Integer, Dimension( 1:n_dim ) :: this_permutation
  Integer, Dimension( 1:n_dim ) :: sizes
  Integer, Dimension( 1:n_dim ) :: permuted_sizes
  Integer, Dimension( 1:n_dim ) :: indices
  Integer, Dimension( 1:n_dim ) :: permuted_indices

  Integer :: n1, n2, n3, n4
  Integer :: l1, l2, l3, l4
  Integer :: iter

  ! 64 bit counters give a finer clock resolution than default integers
  Integer( int64 ) :: start, finish, rate

  Logical :: finished

  c = [ ( i * 0.1_wp, i = 1, Size( c ) ) ]

  Write( *, * ) 'n1, n2, n3, n4?'
  Read ( *, * ) n1, n2, n3, n4

  Allocate( a ( 1:n1, 1:n2, 1:n3, 1:n4 ) )
  Allocate( b, Mold = a )

  Call Random_init( .true., .false. )
  Call Random_number( a )
  ! Make sure a( 1, 1, 1, 1 ) is not zero for the sanity check
  a( 1, 1, 1, 1 ) = a( 1, 1, 1, 1 ) + 0.1_wp

  Call system_clock( start, rate )

  sizes = Shape( a )

  ! b is accumulated over all n_iter repeats; the sanity check divides by n_iter
  b = 0.0_wp

  iter_loop: Do iter = 1, n_iter

     ! Start again at the identity permutation, which pairs with c( 1 )
     Call p%init( n_dim )
     i = 0
     finished = .False.

     permutation_loop: Do While( .Not. finished )

        i = i + 1

        ! Only do it if it has any weight
        If( Abs( c( i ) ) > Epsilon( c( i ) ) ) Then

           ! Get the current permutation
           this_permutation = p%get()

           ! Check the shapes are compatible
           permuted_sizes = sizes( this_permutation )
           If( All( permuted_sizes == sizes ) ) Then

              ! Add in the current permutation
              Do l4 = 1, n4
                 Do l3 = 1, n3
                    Do l2 = 1, n2
                       Do l1 = 1, n1
                          indices = [ l1, l2, l3, l4 ]
                          permuted_indices = indices( this_permutation )
                          b( l1, l2, l3, l4 ) = b( l1, l2, l3, l4 ) + &
                               c( i ) * a( permuted_indices( 1 ), permuted_indices( 2 ), &
                                           permuted_indices( 3 ), permuted_indices( 4 ) )
                       End Do
                    End Do
                 End Do
              End Do

           End If

        End If

        ! Advance only after the current permutation has been used. Stepping
        ! first would skip the identity and apply the last permutation twice.
        finished = p%next()

     End Do permutation_loop

  End Do iter_loop

  Call system_clock( finish, rate )

  Write( *, * ) 'time per iteration = ', &
       Real( finish - start, wp ) / Real( rate, wp ) / Real( n_iter, wp )
  Write( *, * ) 'Sanity ', b( 1, 1, 1, 1 ) / a( 1, 1, 1, 1 ) / n_iter

End Program testit
ijb@ijb-Latitude-5410:~/work/stack$ gfortran --version
GNU Fortran (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
ijb@ijb-Latitude-5410:~/work/stack$ gfortran -fcheck=all -Wall -Wextra -O3 -g -std=f2018 perm_mod_single_thread.f90
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
10 10 10 10
time per iteration = 1.47000002E-03
Sanity 30.000000000000259
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
11 10 9 12
time per iteration = 0.00000000
Sanity 0.0000000000000000
ijb@ijb-Latitude-5410:~/work/stack$ gfortran -Wall -Wextra -O3 -g -std=f2018 perm_mod_single_thread.f90
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
30 30 30 30
time per iteration = 6.56599998E-02
Sanity 29.999999999999844
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
60 60 60 60
time per iteration = 2.46800995
Sanity 30.000000000000036
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
10 20 30 40
time per iteration = 2.00000013E-05
Sanity 0.0000000000000000
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
n1, n2, n3, n4?
30 30 15 15
time per iteration = 1.50999997E-03
Sanity 1.4000000000000050
ijb@ijb-Latitude-5410:~/work/stack$