给定总和的毕达哥拉斯三元组

Pythagorean Triplet with given sum

如果毕达哥拉斯三元组等于输入,下面的代码会打印它,但问题是像 90,000 这样的大数需要很长时间才能回答。 我可以做些什么来优化以下代码? 1 ≤ n ≤ 90 000

def pythagoreanTriplet(n):

    # Considering triplets in
    # sorted order. The value
    # of first element in sorted
    # triplet can be at-most n/3.
    for i in range(1, int(n / 3) + 1):

        # The value of second element
        # must be less than equal to n/2
        for j in range(i + 1,
                       int(n / 2) + 1):

            k = n - i - j
            if (i * i + j * j == k * k):
                print(i, ", ", j, ", ",
                      k, sep="")
                return

    print("Impossible")
# Driver Code
vorodi = int(input())
pythagoreanTriplet(vorodi)

您的 source code 会进行暴力搜索以寻找解决方案,因此速度很慢。

更快的代码

def solve_pythagorean_triplets(n):
  " Solves for triplets whose sum equals n "
  solutions = []
  for a in range(1, n):
    denom = 2*(n-a)
    num = 2*a**2 + n**2 - 2*n*a
    if denom > 0 and num % denom == 0:
      c = num // denom
      b = n - a - c
      if b > a:
        solutions.append((a, b, c))

  return solutions

OP代码

修改了 OP 代码,使其 returns 所有解决方案而不是打印第一个找到的解决方案来比较性能

def pythagoreanTriplet(n): 

    # Considering triplets in  
    # sorted order. The value  
    # of first element in sorted  
    # triplet can be at-most n/3. 
    results = []
    for i in range(1, int(n / 3) + 1):  

        # The value of second element  
        # must be less than equal to n/2 
        for j in range(i + 1,  
                       int(n / 2) + 1):  

            k = n - i - j 
            if (i * i + j * j == k * k):
                results.append((i, j, k))

    return results

时机

 n     pythagoreanTriplet (OP Code)     solve_pythagorean_triplets (new)
  900   0.084 seconds                       0.039 seconds
  5000  3.130 seconds                       0.012 seconds
  90000 Timed out after several minutes     0.430 seconds

说明

函数 solve_pythagorean_triplets 是 O(n) 算法,其工作原理如下。

  1. 正在搜索:

    a^2 + b^2 = c^2 (triplet)
    a + b + c = n   (sum equals input)
    
  2. 通过搜索 a(即迭代的固定值)来解决。有了固定值,我们有两个方程和两个未知数 (b, c):

    b + c = n - a
    c^2 - b^2 = a^2
    
  3. 解法是:

    denom = 2*(n-a)
    num = 2*a**2 + n**2 - 2*n*a
    if denom > 0 and num % denom == 0:
        c = num // denom
        b = n - a - c
        if b > a:
            (a, b, c) # is a solution
    
  4. 迭代范围(1, n)得到不同的解

哟 我不知道你是否还需要答案,但希望这能对你有所帮助。

n = int(input())
ans = [(a, b, c) for a in range(1, n) for b in range(a, n) for c in range(b, n) if (a**2 + b**2 == c**2 and a + b + c == n)]
if ans:
    print(ans[0][0], ans[0][1], ans[0][2])
else:
    print("Impossible")