如何在此循环中利用置换对称性?
How to exploit permutational symmetry in this loop?
我有一个标量函数 f(a,b,c,d)
具有以下排列对称性
f(a,b,c,d) = f(c,d,a,b) = -f(b,a,d,c) = -f(d,c,b,a)
我正在使用它来完全填充 4D 数组。以下代码(使用 python/NumPy)有效:
A = np.zeros((N,N,N,N))
for a in range(N):
for b in range(N):
for c in range(N):
for d in range(N):
A[a,b,c,d] = f(a,b,c,d)
但显然我想利用对称性来减少这部分代码的执行时间。我试过:
A = np.zeros((N,N,N,N))
ab = 0
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
for d in range(N):
cd += 1
if ab >= cd:
A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
执行时间减半。但是对于我尝试的最后一个对称性:
A = np.zeros((N,N,N,N))
ab = 0
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
for d in range(N):
cd += 1
if ab >= cd:
if ((a >= b) or (c >= d)):
A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
A[b,a,d,c] = A[d,c,b,a] = -A[a,b,c,d]
这行得通,但没有给我接近两倍加速的另一个因素。我不认为这是正确的理由,但不明白为什么。
我怎样才能更好地利用这里的这种特殊的排列对称性?
有趣的问题!
对于N=3
,应该有4个元素的81种组合。
用你的循环,你创造了 156.
看来重复的主要来源是(a >= b) or (c >= d)
中的or
,太宽容了。不过,(a >= b) and (c >= d)
限制太多了。
不过你可以比较一下 a + c >= b + d
。要获得几毫秒(如果有的话),您可以在第三个循环中将 a + c
保存为 ac
:
A = np.zeros((N,N,N,N))
ab = 0
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
ac = a + c
for d in range(N):
cd += 1
if (ab >= cd and ac >= b+d):
A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
A[b,a,d,c] = A[d,c,b,a] = -A[a,b,c,d]
使用此代码,我们创建了 112 种组合,因此与使用您的方法相比,重复项更少,但可能仍存在一些优化。
更新
这是我用来计算创建组合数的代码:
from itertools import product
N = 3
ab = 0
all_combinations = set(product(range(N), repeat=4))
zeroes = ((x, x, y, y) for x, y in product(range(N), repeat=2))
calculated = list()
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
ac = a + c
for d in range(N):
cd += 1
if (ab >= cd and ac >= b + d) and not (a == b and c == d):
calculated.append((a, b, c, d))
calculated.append((c, d, a, b))
calculated.append((b, a, d, c))
calculated.append((d, c, b, a))
missing = all_combinations - set(calculated) - set(zeroes)
if missing:
print "Some sets weren't calculated :"
for s in missing:
print s
else:
print "All cases were covered"
print len(calculated)
有了and not (a==b and c==d)
,这个数字就降到了88。
我有一个标量函数 f(a,b,c,d)
具有以下排列对称性
f(a,b,c,d) = f(c,d,a,b) = -f(b,a,d,c) = -f(d,c,b,a)
我正在使用它来完全填充 4D 数组。以下代码(使用 python/NumPy)有效:
A = np.zeros((N,N,N,N))
for a in range(N):
for b in range(N):
for c in range(N):
for d in range(N):
A[a,b,c,d] = f(a,b,c,d)
但显然我想利用对称性来减少这部分代码的执行时间。我试过:
A = np.zeros((N,N,N,N))
ab = 0
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
for d in range(N):
cd += 1
if ab >= cd:
A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
执行时间减半。但是对于我尝试的最后一个对称性:
A = np.zeros((N,N,N,N))
ab = 0
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
for d in range(N):
cd += 1
if ab >= cd:
if ((a >= b) or (c >= d)):
A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
A[b,a,d,c] = A[d,c,b,a] = -A[a,b,c,d]
这行得通,但没有给我接近两倍加速的另一个因素。我不认为这是正确的理由,但不明白为什么。
我怎样才能更好地利用这里的这种特殊的排列对称性?
有趣的问题!
对于N=3
,应该有4个元素的81种组合。
用你的循环,你创造了 156.
看来重复的主要来源是(a >= b) or (c >= d)
中的or
,太宽容了。不过,(a >= b) and (c >= d)
限制太多了。
不过你可以比较一下 a + c >= b + d
。要获得几毫秒(如果有的话),您可以在第三个循环中将 a + c
保存为 ac
:
A = np.zeros((N,N,N,N))
ab = 0
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
ac = a + c
for d in range(N):
cd += 1
if (ab >= cd and ac >= b+d):
A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
A[b,a,d,c] = A[d,c,b,a] = -A[a,b,c,d]
使用此代码,我们创建了 112 种组合,因此与使用您的方法相比,重复项更少,但可能仍存在一些优化。
更新
这是我用来计算创建组合数的代码:
from itertools import product
N = 3
ab = 0
all_combinations = set(product(range(N), repeat=4))
zeroes = ((x, x, y, y) for x, y in product(range(N), repeat=2))
calculated = list()
for a in range(N):
for b in range(N):
ab += 1
cd = 0
for c in range(N):
ac = a + c
for d in range(N):
cd += 1
if (ab >= cd and ac >= b + d) and not (a == b and c == d):
calculated.append((a, b, c, d))
calculated.append((c, d, a, b))
calculated.append((b, a, d, c))
calculated.append((d, c, b, a))
missing = all_combinations - set(calculated) - set(zeroes)
if missing:
print "Some sets weren't calculated :"
for s in missing:
print s
else:
print "All cases were covered"
print len(calculated)
有了and not (a==b and c==d)
,这个数字就降到了88。