三元表示中的数字快速求和 (Python)

Fast sum of digits in a ternary representation (Python)

我定义了一个函数

def enumerateSpin(n):
    s = []
    for a in range(0,3**n):
        ternary_rep = np.base_repr(a,3)
        k = len(ternary_rep)
        r = (n-k)*'0'+ternary_rep
        if sum(map(int,r)) == n:
            s.append(r)
    return s

我查看一个数字 0 <= a < 3^N 并询问它在三元表示中的数字总和是否等于某个值。我通过首先将数字转换为其三元表示形式的字符串来做到这一点。我正在填充零,因为我想存储一个固定长度表示的列表,我以后可以将其用于进一步的计算(即两个元素之间的逐位比较)。

现在 np.base_reprsum(map(int,#)) 在我的计算机上分别需要大约 5 us,这意味着迭代大约需要 10 us,我正在寻找一种可以完成我所做的事情的方法但是快 10 倍。

(编辑:注意左侧填充零)

(Edit2:事后看来,最终表示最好是整数元组而不是字符串)。

(Edit3:对于那些想知道的人,代码的目的是枚举具有相同总 S_z 值的自旋 1 链的状态。)

这是一种多处理方法。它将节省更多时间,问题规模越大

import multiprocessing as mp


def filter(n, qIn, qOut):  # this is the function that will be parallelized
    nums = range(3**n)
    answer = []
    for low,high in iter(qIn.get, None):
        for num in nums[low:high]:
            r = np.base_repr(num, 3)  # ternary representation
            if sum(int(i) for i in r) == num:  # this is your check
                answer.append('0'*(n-len(r)) +r)  # turn it into a fixed length
    qOut.put(answer)
    qOut.put(None)


def enumerateSpin(n):  # this is the primary entry point
    numProcs = mp.cpu_count()-1  # fiddle to taste
    chunkSize = n//numProcs

    qIn, qOut = [mp.Queue() for _ in range(2)]
    procs = [mp.Process(target=filter, args=(n, qIn, qOut)) for _ in range(numProcs)]

    for p in procs: p.start()
    for i in range(0, 3**n, chunkSize):  # chunkify your numbers so that IPC is more efficient
        qIn.put((i, i+chunkSize))
    for p in procs: qIn.put(None)

    answer = []
    done = 0
    while done < len(procs):
        t = qOut.get()
        if t is None:
            done += 1
            continue
        answer.extend(t)

    for p in procs: p.terminate()

    return answer

您可以使用itertools.product生成数字然后转换为字符串表示:

import itertools as it

def new(n):
    s = []
    for digits in it.product((0, 1, 2), repeat=n):
        if sum(digits) == n:
            s.append(''.join(str(x) for x in digits))
    return s

这给了我大约 7 倍的加速:

In [8]: %timeit enumerateSpin(12)
2.39 s ± 7.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [9]: %timeit new(12)
347 ms ± 4.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

在 Python 3.9.0 (IPython 7.20.0) (Linux).

上测试

上述过程,使用 it.product,还生成了我们通过推理知道它们不符合条件的数字(这是所有数字的一半的情况,因为数字总和必须等于位数)。对于 n 个数字,我们可以计算 210 的各种数字计数,最终总和为 n。然后我们可以生成这些数字的所有 distinct permutations,因此只生成相关数字:

import itertools as it
from more_itertools import distinct_permutations

def new2(n):
    all_digits = (('2',)*i + ('1',)*(n-2*i) + ('0',)*i for i in range(n//2+1))
    all_digits = it.chain.from_iterable(distinct_permutations(d) for d in all_digits)
    return (''.join(digits) for digits in all_digits)

特别是对于大量 n 这提供了额外的显着加速:

In [44]: %timeit -r 1 -n 1 new(16)
31.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [45]: %timeit -r 1 -n 1 list(new2(16))
7.82 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

请注意,上述解决方案 newnew2 具有 O(1) 内存缩放(将 new 更改为 yield 而不是 append)。

一般来说,要获取特定基数的数字,我们可以这样做:

while num > 0:
    digit = num % base
    num //= base
    print(digit)

当 运行 这与 num = 14, base = 3 我们得到:

2
1
1

也就是说14的三进制是112
我们可以将其提取到方法 digits(num, base) 中,并且仅在我们实际需要将数字转换为字符串时才使用 np.base_repr(a,3)

def enumerateSpin(n):
    s = []
    for a in range(0,3**n):
        if sum(digits(a, 3)) == n:
            ternary_rep = np.base_repr(a,3)
            k = len(ternary_rep)
            r = (n-k)*'0'+ternary_rep
            s.append(r)
    return s

enumerateSpin(4)的输出:

['0022', '0112', '0121', '0202', '0211', '0220', '1012', '1021', '1102', '1111', '1120', '1201', '1210', '2002', '2011', '2020', '2101', '2110', '2200']

通过将所有计算委托给 numpy 以利用矢量化处理,可以实现 10 倍的改进:

def eSpin(n):
    nums    = np.arange(3**n,dtype=np.int)
    base3   = nums // (3**np.arange(n))[:,None] % 3
    matches = np.sum(base3,axis=0) == n
    digits  = np.sum(base3[:,matches] * 10**np.arange(n)[:,None],axis=0)
    return [f"{a:0{n}}" for a in digits]   

工作原理(eSpin(3) 示例):

nums 是一个包含最多 3**n

个数字的数组
   [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26]  

base3将其转换为附加维度的3进制数字:

[[0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2]
 [0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2]
 [0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2]]

matches 标识 base3 数字之和与 n

匹配的列
 [0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 0 0 0]

digits 将匹配列转换为由 base3 数字组成的 base 10 数字

 [ 12  21 102 111 120 201 210]

最后匹配的 (base10) 数字用前导零格式化。

表现:

from timeit import timeit
count = 1

print(enumerateSpin(10)==eSpin(10)) # True

t1 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t1) # 0.634 sec

t0 = timeit(lambda:enumerateSpin(13),number=count)
print("enumerateSpin",t0) # 7.362 sec

元组版本:

def eSpin2(n):
    nums    = np.arange(3**n,dtype=np.int)
    base3   = nums// (3**np.arange(n))[:,None]  % 3
    matches = np.sum(base3,axis=0) == n
    return [*map(tuple,base3[:,matches].T)]

eSpin2(3)
[(2, 1, 0), (1, 2, 0), (2, 0, 1), (1, 1, 1), (0, 2, 1), (1, 0, 2), (0, 1, 2)]

[编辑] 一种更快的方法 (比 enumerateSpin 快 40 到 80 倍)

使用动态规划和记忆可以提供更好的性能:

@lru_cache()
def eSpin(n,base=3,target=None):
    if target is None: target = n
    if target == 0: return [(0,)*n]
    if target>base**n-1: return []
    if n==1: return [(target,)]
    result = []
    for d in range(min(base,target+1)):
        result.extend((d,)+suffix for suffix in eSpin(n-1,base,target-d) )
    return result

t4 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t4) # 0.108 sec

eSpin.cache_clear()
t5 = timeit(lambda:eSpin(16),number=count)
print("eSpin",t5) # 2.25 sec