三元表示中的数字快速求和 (Python)
Fast sum of digits in a ternary representation (Python)
我定义了一个函数
def enumerateSpin(n):
s = []
for a in range(0,3**n):
ternary_rep = np.base_repr(a,3)
k = len(ternary_rep)
r = (n-k)*'0'+ternary_rep
if sum(map(int,r)) == n:
s.append(r)
return s
我查看一个数字 0 <= a < 3^N 并询问它在三元表示中的数字总和是否等于某个值。我通过首先将数字转换为其三元表示形式的字符串来做到这一点。我正在填充零,因为我想存储一个固定长度表示的列表,我以后可以将其用于进一步的计算(即两个元素之间的逐位比较)。
现在 np.base_repr
和 sum(map(int,#))
在我的计算机上分别需要大约 5 us,这意味着迭代大约需要 10 us,我正在寻找一种可以完成我所做的事情的方法但是快 10 倍。
(编辑:注意左侧填充零)
(Edit2:事后看来,最终表示最好是整数元组而不是字符串)。
(Edit3:对于那些想知道的人,代码的目的是枚举具有相同总 S_z 值的自旋 1 链的状态。)
这是一种多处理方法。它将节省更多时间,问题规模越大
import multiprocessing as mp
def filter(n, qIn, qOut): # this is the function that will be parallelized
nums = range(3**n)
answer = []
for low,high in iter(qIn.get, None):
for num in nums[low:high]:
r = np.base_repr(num, 3) # ternary representation
if sum(int(i) for i in r) == num: # this is your check
answer.append('0'*(n-len(r)) +r) # turn it into a fixed length
qOut.put(answer)
qOut.put(None)
def enumerateSpin(n): # this is the primary entry point
numProcs = mp.cpu_count()-1 # fiddle to taste
chunkSize = n//numProcs
qIn, qOut = [mp.Queue() for _ in range(2)]
procs = [mp.Process(target=filter, args=(n, qIn, qOut)) for _ in range(numProcs)]
for p in procs: p.start()
for i in range(0, 3**n, chunkSize): # chunkify your numbers so that IPC is more efficient
qIn.put((i, i+chunkSize))
for p in procs: qIn.put(None)
answer = []
done = 0
while done < len(procs):
t = qOut.get()
if t is None:
done += 1
continue
answer.extend(t)
for p in procs: p.terminate()
return answer
您可以使用itertools.product
生成数字然后转换为字符串表示:
import itertools as it
def new(n):
s = []
for digits in it.product((0, 1, 2), repeat=n):
if sum(digits) == n:
s.append(''.join(str(x) for x in digits))
return s
这给了我大约 7 倍的加速:
In [8]: %timeit enumerateSpin(12)
2.39 s ± 7.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [9]: %timeit new(12)
347 ms ± 4.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
在 Python 3.9.0 (IPython 7.20.0) (Linux).
上测试
上述过程,使用 it.product
,还生成了我们通过推理知道它们不符合条件的数字(这是所有数字的一半的情况,因为数字总和必须等于位数)。对于 n
个数字,我们可以计算 2
、1
和 0
的各种数字计数,最终总和为 n
。然后我们可以生成这些数字的所有 distinct permutations,因此只生成相关数字:
import itertools as it
from more_itertools import distinct_permutations
def new2(n):
all_digits = (('2',)*i + ('1',)*(n-2*i) + ('0',)*i for i in range(n//2+1))
all_digits = it.chain.from_iterable(distinct_permutations(d) for d in all_digits)
return (''.join(digits) for digits in all_digits)
特别是对于大量 n
这提供了额外的显着加速:
In [44]: %timeit -r 1 -n 1 new(16)
31.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
In [45]: %timeit -r 1 -n 1 list(new2(16))
7.82 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
请注意,上述解决方案 new
和 new2
具有 O(1) 内存缩放(将 new
更改为 yield
而不是 append
)。
一般来说,要获取特定基数的数字,我们可以这样做:
while num > 0:
digit = num % base
num //= base
print(digit)
当 运行 这与 num = 14, base = 3
我们得到:
2
1
1
也就是说14的三进制是112
我们可以将其提取到方法 digits(num, base)
中,并且仅在我们实际需要将数字转换为字符串时才使用 np.base_repr(a,3)
:
def enumerateSpin(n):
s = []
for a in range(0,3**n):
if sum(digits(a, 3)) == n:
ternary_rep = np.base_repr(a,3)
k = len(ternary_rep)
r = (n-k)*'0'+ternary_rep
s.append(r)
return s
enumerateSpin(4)
的输出:
['0022', '0112', '0121', '0202', '0211', '0220', '1012', '1021', '1102', '1111', '1120', '1201', '1210', '2002', '2011', '2020', '2101', '2110', '2200']
通过将所有计算委托给 numpy 以利用矢量化处理,可以实现 10 倍的改进:
def eSpin(n):
nums = np.arange(3**n,dtype=np.int)
base3 = nums // (3**np.arange(n))[:,None] % 3
matches = np.sum(base3,axis=0) == n
digits = np.sum(base3[:,matches] * 10**np.arange(n)[:,None],axis=0)
return [f"{a:0{n}}" for a in digits]
工作原理(eSpin(3) 示例):
nums
是一个包含最多 3**n
个数字的数组
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26]
base3
将其转换为附加维度的3进制数字:
[[0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2]
[0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2]
[0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2]]
matches
标识 base3 数字之和与 n
匹配的列
[0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 0 0 0]
digits
将匹配列转换为由 base3 数字组成的 base 10 数字
[ 12 21 102 111 120 201 210]
最后匹配的 (base10) 数字用前导零格式化。
表现:
from timeit import timeit
count = 1
print(enumerateSpin(10)==eSpin(10)) # True
t1 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t1) # 0.634 sec
t0 = timeit(lambda:enumerateSpin(13),number=count)
print("enumerateSpin",t0) # 7.362 sec
元组版本:
def eSpin2(n):
nums = np.arange(3**n,dtype=np.int)
base3 = nums// (3**np.arange(n))[:,None] % 3
matches = np.sum(base3,axis=0) == n
return [*map(tuple,base3[:,matches].T)]
eSpin2(3)
[(2, 1, 0), (1, 2, 0), (2, 0, 1), (1, 1, 1), (0, 2, 1), (1, 0, 2), (0, 1, 2)]
[编辑] 一种更快的方法 (比 enumerateSpin 快 40 到 80 倍)
使用动态规划和记忆可以提供更好的性能:
@lru_cache()
def eSpin(n,base=3,target=None):
if target is None: target = n
if target == 0: return [(0,)*n]
if target>base**n-1: return []
if n==1: return [(target,)]
result = []
for d in range(min(base,target+1)):
result.extend((d,)+suffix for suffix in eSpin(n-1,base,target-d) )
return result
t4 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t4) # 0.108 sec
eSpin.cache_clear()
t5 = timeit(lambda:eSpin(16),number=count)
print("eSpin",t5) # 2.25 sec
我定义了一个函数
def enumerateSpin(n):
s = []
for a in range(0,3**n):
ternary_rep = np.base_repr(a,3)
k = len(ternary_rep)
r = (n-k)*'0'+ternary_rep
if sum(map(int,r)) == n:
s.append(r)
return s
我查看一个数字 0 <= a < 3^N 并询问它在三元表示中的数字总和是否等于某个值。我通过首先将数字转换为其三元表示形式的字符串来做到这一点。我正在填充零,因为我想存储一个固定长度表示的列表,我以后可以将其用于进一步的计算(即两个元素之间的逐位比较)。
现在 np.base_repr
和 sum(map(int,#))
在我的计算机上分别需要大约 5 us,这意味着迭代大约需要 10 us,我正在寻找一种可以完成我所做的事情的方法但是快 10 倍。
(编辑:注意左侧填充零)
(Edit2:事后看来,最终表示最好是整数元组而不是字符串)。
(Edit3:对于那些想知道的人,代码的目的是枚举具有相同总 S_z 值的自旋 1 链的状态。)
这是一种多处理方法。它将节省更多时间,问题规模越大
import multiprocessing as mp
def filter(n, qIn, qOut): # this is the function that will be parallelized
nums = range(3**n)
answer = []
for low,high in iter(qIn.get, None):
for num in nums[low:high]:
r = np.base_repr(num, 3) # ternary representation
if sum(int(i) for i in r) == num: # this is your check
answer.append('0'*(n-len(r)) +r) # turn it into a fixed length
qOut.put(answer)
qOut.put(None)
def enumerateSpin(n): # this is the primary entry point
numProcs = mp.cpu_count()-1 # fiddle to taste
chunkSize = n//numProcs
qIn, qOut = [mp.Queue() for _ in range(2)]
procs = [mp.Process(target=filter, args=(n, qIn, qOut)) for _ in range(numProcs)]
for p in procs: p.start()
for i in range(0, 3**n, chunkSize): # chunkify your numbers so that IPC is more efficient
qIn.put((i, i+chunkSize))
for p in procs: qIn.put(None)
answer = []
done = 0
while done < len(procs):
t = qOut.get()
if t is None:
done += 1
continue
answer.extend(t)
for p in procs: p.terminate()
return answer
您可以使用itertools.product
生成数字然后转换为字符串表示:
import itertools as it
def new(n):
s = []
for digits in it.product((0, 1, 2), repeat=n):
if sum(digits) == n:
s.append(''.join(str(x) for x in digits))
return s
这给了我大约 7 倍的加速:
In [8]: %timeit enumerateSpin(12)
2.39 s ± 7.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [9]: %timeit new(12)
347 ms ± 4.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
在 Python 3.9.0 (IPython 7.20.0) (Linux).
上测试上述过程,使用 it.product
,还生成了我们通过推理知道它们不符合条件的数字(这是所有数字的一半的情况,因为数字总和必须等于位数)。对于 n
个数字,我们可以计算 2
、1
和 0
的各种数字计数,最终总和为 n
。然后我们可以生成这些数字的所有 distinct permutations,因此只生成相关数字:
import itertools as it
from more_itertools import distinct_permutations
def new2(n):
all_digits = (('2',)*i + ('1',)*(n-2*i) + ('0',)*i for i in range(n//2+1))
all_digits = it.chain.from_iterable(distinct_permutations(d) for d in all_digits)
return (''.join(digits) for digits in all_digits)
特别是对于大量 n
这提供了额外的显着加速:
In [44]: %timeit -r 1 -n 1 new(16)
31.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
In [45]: %timeit -r 1 -n 1 list(new2(16))
7.82 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
请注意,上述解决方案 new
和 new2
具有 O(1) 内存缩放(将 new
更改为 yield
而不是 append
)。
一般来说,要获取特定基数的数字,我们可以这样做:
while num > 0:
digit = num % base
num //= base
print(digit)
当 运行 这与 num = 14, base = 3
我们得到:
2
1
1
也就是说14的三进制是112
我们可以将其提取到方法 digits(num, base)
中,并且仅在我们实际需要将数字转换为字符串时才使用 np.base_repr(a,3)
:
def enumerateSpin(n):
s = []
for a in range(0,3**n):
if sum(digits(a, 3)) == n:
ternary_rep = np.base_repr(a,3)
k = len(ternary_rep)
r = (n-k)*'0'+ternary_rep
s.append(r)
return s
enumerateSpin(4)
的输出:
['0022', '0112', '0121', '0202', '0211', '0220', '1012', '1021', '1102', '1111', '1120', '1201', '1210', '2002', '2011', '2020', '2101', '2110', '2200']
通过将所有计算委托给 numpy 以利用矢量化处理,可以实现 10 倍的改进:
def eSpin(n):
nums = np.arange(3**n,dtype=np.int)
base3 = nums // (3**np.arange(n))[:,None] % 3
matches = np.sum(base3,axis=0) == n
digits = np.sum(base3[:,matches] * 10**np.arange(n)[:,None],axis=0)
return [f"{a:0{n}}" for a in digits]
工作原理(eSpin(3) 示例):
nums
是一个包含最多 3**n
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26]
base3
将其转换为附加维度的3进制数字:
[[0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2]
[0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2]
[0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2]]
matches
标识 base3 数字之和与 n
[0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 0 0 0]
digits
将匹配列转换为由 base3 数字组成的 base 10 数字
[ 12 21 102 111 120 201 210]
最后匹配的 (base10) 数字用前导零格式化。
表现:
from timeit import timeit
count = 1
print(enumerateSpin(10)==eSpin(10)) # True
t1 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t1) # 0.634 sec
t0 = timeit(lambda:enumerateSpin(13),number=count)
print("enumerateSpin",t0) # 7.362 sec
元组版本:
def eSpin2(n):
nums = np.arange(3**n,dtype=np.int)
base3 = nums// (3**np.arange(n))[:,None] % 3
matches = np.sum(base3,axis=0) == n
return [*map(tuple,base3[:,matches].T)]
eSpin2(3)
[(2, 1, 0), (1, 2, 0), (2, 0, 1), (1, 1, 1), (0, 2, 1), (1, 0, 2), (0, 1, 2)]
[编辑] 一种更快的方法 (比 enumerateSpin 快 40 到 80 倍)
使用动态规划和记忆可以提供更好的性能:
@lru_cache()
def eSpin(n,base=3,target=None):
if target is None: target = n
if target == 0: return [(0,)*n]
if target>base**n-1: return []
if n==1: return [(target,)]
result = []
for d in range(min(base,target+1)):
result.extend((d,)+suffix for suffix in eSpin(n-1,base,target-d) )
return result
t4 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t4) # 0.108 sec
eSpin.cache_clear()
t5 = timeit(lambda:eSpin(16),number=count)
print("eSpin",t5) # 2.25 sec