寻找可以写成几个不同平方数之和的数字

Finding numbers that can be written as the sum of several distinct squared numbers

最近遇到一个很有意思的问题:

Given the number N, how many combinations exist that can be written as the sum of several distinct squared numbers?

例如25可以写成:

25 = [5**2 , 3**2 + 4**2]

所以答案是 2。

但是对于31,没有解,所以答案是0。

一开始,我认为这是一个很容易解决的问题。我只是假设如果我计算小于数字 N 的平方根的平方数的每个组合,那么我应该没问题,时间方面。 (我找到了基于 this 的组合)

def combs(a):
    if len(a) == 0:
        return [[]]
    cs = []
    for c in combs(a[1:]):
        cs += [c, c+[a[0]]]
    return cs

def main(n):
    k = 0
    l = combs(range(1, int(n**0.5)+1))
    for i in l:
        s = 0
        for j in i:
            s += j**2
        if s == n:
            print(i, n)
            k += 1
    print('\nanswer: ',k)

if __name__ == '__main__':
    main(n = 25)

但是,如果你复制我的解决方案,你会发现在N > 400之后,这个问题在短时间内基本无法解决。所以,我想知道是否存在针对此问题的优化?

以下 MiniZinc 模型在不到一秒的时间内处理 n=400(55 个解决方案):

int: n = 25;
set of int: Domain = 1..ceil(pow(n, 0.5));

%  the array of decision variable decides
%  which integers between 1 and n^0.5 are added as squares
array[Domain] of var bool: b;

constraint n == sum([b[i] * i * i | i in Domain]);

output ["\(n) = "] ++ [if fix(b[i]) then "+\(i)²" else "" endif | i in Domain];

通过暴力评估 n=400 的解决方案:
枚举 1,048,576 个 20 位整数并将它们注册为产生所需总和的解。二十位中的每一位决定应该对哪个整数 1..20 进行平方和相加。 循环一百万个案例并不需要那么长时间。

实施:

from functools import cache
from math import sqrt

@cache
def _square_sums(n, max_i):
    if n == 0:
        return 1
    start = min(max_i, int(sqrt(n)))
    return sum(_square_sums(n - i**2, i - 1) for i in range(start, 0, -1))

def square_sums(n):
    return _square_sums(n, max_i=n)

测试:

>>> square_sums(25)
2

>>> square_sums(55)
1

>>> square_sums(400)
55

>>> square_sums(10000)
3296089777

>>> square_sums(100000)
2759256389896728737285379

性能特点:

n < 10000 快,但 n 大得多时就慢。

您可以使用标准的single-use“硬币找零”算法,硬币的价值是正方形。

def sum_distinct_squares(n):
    W = [1] + [0] * n
    for i in range(1, n+1):
        i2 = i * i
        if i2 > n:
            break
        for j in range(n, i2-1, -1):
            W[j] += W[j - i2]
    return W[n]

print(sum_distinct_squares(100000))

这在 O(n * sqrt(n)) 时间内运行,并在我的机器上用 0.016 秒解决了 n=400 的情况,在 1.715 秒内解决了 n=100000 的情况。

基于@PaulHankin 的出色回答,还可以使用快速 Numba JIT 编译器进行另一种优化:

import time
import numba
import matplotlib.pyplot as plt

@numba.jit
def sum_distinct_squares_jit(n):
    W = [1] + [0] * n
    for i in range(1, n+1):
        i2 = i * i
        if i2 > n:
            break
        for j in range(n, i2-1, -1):
            W[j] += W[j - i2]
    return W[n]

def sum_distinct_squares_non_jit(n):
    W = [1] + [0] * n
    for i in range(1, n+1):
        i2 = i * i
        if i2 > n:
            break
        for j in range(n, i2-1, -1):
            W[j] += W[j - i2]
    return W[n]

def loop(N, jit = None):
    if jit is True:
        sds = sum_distinct_squares_jit
    if jit is False:
        sds = sum_distinct_squares_non_jit
  
    times = []
    t_i = time.clock()
    for number in range(1, N):
        sds(number)
        times.append(time.clock() - t_i)
    plt.plot(times)

loop(N = 10_000, jit = True)
loop(N = 10_000, jit = False)
plt.legend(['JIT', 'None JIT'])
plt.show()

运行 这是一个 for 循环并绘制 JIT 和 Non-JIT 所花费的时间,我们得到:

因此,使用 JIT 的速度要快得多。然而,这是有代价的:Numba 不支持 bigint 库,对于大数字,使用 Numba 是无效的。

只有一组有限的数字不是不同平方和 (http://oeis.org/A001422)

所以这是一个 O(1) 的解决方案:

If N not in [2, 3, 6, 7, 8, 11, 12, 15, 18, 19, 22, 23, 24, 27, 28, 31, 32, 33, 43, 44, 47, 48, 60, 67, 72, 76, 92, 96, 108, 112, 128]:
  Print("N is a sum of distinct squares")