求解 x^2 + y^2 + z^2 = N 得到 x, y, z 的所有唯一组合
Solving for x^2 + y^2 + z^2 = N to get all the unique combination of x, y, z
问题是我必须找到所有可能的整数组合 (x, y, z) 来满足等式 x^2 + y^2 + z^2 = N 当你给定一个整数 N . 你必须找到所有唯一的元组 (x, y, z)。例如,如果元组之一是 (1, 2, 1),则 (2, 1, 1) 不再是唯一的。
def find(n):
## max of x can be sqrt of n
n1 = math.ceil(n ** (1/2))
lst = []
for i in range(1, n1):
if (i ** 2) >= n:
break
for j in range(1, (i + 1)):
if (i ** 2 + j ** 2) >= n:
break
if (i ** 2 + i ** 2 + i ** 2) < n:
break
for k in range(1, (j + 1)):
if (i ** 2 + j ** 2 + j ** 2) < n:
break
if (i ** 2 + j ** 2 + k ** 2) > n:
break
if i ** 2 + j ** 2 + k ** 2 == n:
a = [i, j, k]
lst.append(a)
return lst
我试着做了一些优化。例如,当不可能再满足方程式时,我将停止执行循环。但它仍然没有优化。我有一个带有 8 位整数的测试用例,例如 12345678。我的代码需要很长时间才能求解 8 位整数。
还有什么其他的优化可以吗?
谢谢!
更新
这是最近的工作代码,它提供唯一值并使用 list
。
import math
def find(n):
# max of x can be sqrt of n
n1 = int(math.ceil(math.sqrt(n)))
lst = list()
for i in range(1, n1):
for j in range(i + 1, n1):
if i**2 + j**2 >= n:
break
tz = int(math.sqrt(n - i**2 - j**2))
if tz ** 2 == n - i**2 - j**2:
#This if-block makes the elements in tuple, sorted.
#so that set can compare two values for equality
#because (1, 2, 3) != (2, 3, 1)
if tz < i:
i, j, tz = tz, i, j
elif tz < j:
j, tz = tz, j
lst.append((i, j, tz))
lst.sort()
j = 0
i = 0
while i < len(lst):
lst[j] = lst[i]
if lst[i] != lst[j]:
i = i + 1
while i < len(lst) and lst[i] == lst[j]:
i = i + 1
j = j + 1
lst = lst[:j]
return lst
print find(12345678)
您的解决方案绝对可以优化。不需要 k
迭代的循环。您可以使用简单的数学运算来摆脱第三个循环。
我的 python
技能有点生疏,但下面的代码有效。
import math
def find(n):
# max of x can be sqrt of n
n1 = int(math.ceil(math.sqrt(n)))
lst = []
for i in range(1, n1):
for j in range(i + 1, n1):
if i**2 + j**2 >= n:
break
tz = int(math.sqrt(n - i**2 - j**2))
if tz ** 2 == n - i**2 - j**2:
a = [i, j, tz]
lst.append(a)
return lst
print find(12345678)
已测试:here
对于唯一值,您可以使用不同的数据结构,例如维护唯一值的集合,或者放置另一个 if-check 以查看是否已经存在类似的条目。
lst = set()
for i in range(1, n1):
for j in range(i + 1, n1):
if i**2 + j**2 >= n:
break
tz = int(math.sqrt(n - i**2 - j**2))
if tz ** 2 == n - i**2 - j**2:
#This if-block makes the elements in tuple, sorted.
#so that set can compare two values for equality
#because (1, 2, 3) != (2, 3, 1)
if tz < i:
i, j, tz = tz, i, j
elif tz < j:
j, tz = tz, j
a = (i, j, tz)
lst.add(a)
已测试:here
或
如果您想继续该列表,您可以添加另一个检查并仅在当前值集不在列表中时附加到列表。
if tz < i:
i, j, tz = tz, i, j
elif tz < j:
j, tz = tz, j
a = [i, j, tz]
if not(a in lst):
lst.append(a)
两者都是以增加时间复杂度为代价的。但是,使用 set
的 IMO 非常适合,因为它的平均情况复杂度是 O(1)
,而在 list
中搜索元素的存在是 O(n)
.
以下是您可以在@vishal-wadhwa 的基础上进行的一些小优化。
您可以检查任何平方除以 8 得到的余数是 0 或 4 或 1。
例如,使用它您可以立即判断如果 n 除以 8 有余数 7,则无解。
如果 n 是 4 的倍数,则只有偶数个解,所以你可以除以 4 来求解。
最后,如果 n 除以 4 得到余数 1、2 或 3,则解中必须有 1、2 或 3 个奇数,以及 2、1 或 0 个偶数。
下面的代码使用所有这些来偷工减料。
import math
def new_find(n):
if n == 0:
return [(0, 0, 0)]
f = 1
while n % 4 == 0:
n //= 4
f *= 2
if f > 1:
return [(f*a, f*b, f*c) for (a, b, c) in find(n)]
if n % 8 == 7:
return []
offs = 0 if n % 4 == 1 else 1
split = 3 if n % 4 == 3 else 2
sol = []
for i in range(offs, n, 2):
if i*i > n // split:
break
for j in range(i, n, 2):
if j*j > (n - i*i) / (split-1):
break
rem = n - i*i - j*j
rs = int(math.sqrt(rem))
if rem == rs*rs:
sol.append((i, j, rs))
return sol
def orig_find(n):
lst = set()
n1 = int(math.ceil(math.sqrt(n)))
for i in range(0, n1):
for j in range(i, n1):
if i**2 + j**2 >= n:
break
tz = int(math.sqrt(n - i**2 - j**2))
if tz ** 2 == n - i**2 - j**2:
#This if-block makes the elements in tuple, sorted.
#so that set can compare two values for equality
#because (1, 2, 3) != (2, 3, 1)
if tz < i:
i, j, tz = tz, i, j
elif tz < j:
j, tz = tz, j
a = (i, j, tz)
lst.add(a)
return lst
def check_equal(a, b):
import operator
return all(map(operator.eq, sorted(map(sorted, a)), sorted(map(sorted, b))))
print(check_equal(new_find(12345678), orig_find(12345678)))
# True
问题是我必须找到所有可能的整数组合 (x, y, z) 来满足等式 x^2 + y^2 + z^2 = N 当你给定一个整数 N . 你必须找到所有唯一的元组 (x, y, z)。例如,如果元组之一是 (1, 2, 1),则 (2, 1, 1) 不再是唯一的。
def find(n):
## max of x can be sqrt of n
n1 = math.ceil(n ** (1/2))
lst = []
for i in range(1, n1):
if (i ** 2) >= n:
break
for j in range(1, (i + 1)):
if (i ** 2 + j ** 2) >= n:
break
if (i ** 2 + i ** 2 + i ** 2) < n:
break
for k in range(1, (j + 1)):
if (i ** 2 + j ** 2 + j ** 2) < n:
break
if (i ** 2 + j ** 2 + k ** 2) > n:
break
if i ** 2 + j ** 2 + k ** 2 == n:
a = [i, j, k]
lst.append(a)
return lst
我试着做了一些优化。例如,当不可能再满足方程式时,我将停止执行循环。但它仍然没有优化。我有一个带有 8 位整数的测试用例,例如 12345678。我的代码需要很长时间才能求解 8 位整数。
还有什么其他的优化可以吗? 谢谢!
更新
这是最近的工作代码,它提供唯一值并使用 list
。
import math
def find(n):
# max of x can be sqrt of n
n1 = int(math.ceil(math.sqrt(n)))
lst = list()
for i in range(1, n1):
for j in range(i + 1, n1):
if i**2 + j**2 >= n:
break
tz = int(math.sqrt(n - i**2 - j**2))
if tz ** 2 == n - i**2 - j**2:
#This if-block makes the elements in tuple, sorted.
#so that set can compare two values for equality
#because (1, 2, 3) != (2, 3, 1)
if tz < i:
i, j, tz = tz, i, j
elif tz < j:
j, tz = tz, j
lst.append((i, j, tz))
lst.sort()
j = 0
i = 0
while i < len(lst):
lst[j] = lst[i]
if lst[i] != lst[j]:
i = i + 1
while i < len(lst) and lst[i] == lst[j]:
i = i + 1
j = j + 1
lst = lst[:j]
return lst
print find(12345678)
您的解决方案绝对可以优化。不需要
k
迭代的循环。您可以使用简单的数学运算来摆脱第三个循环。
我的 python
技能有点生疏,但下面的代码有效。
import math
def find(n):
# max of x can be sqrt of n
n1 = int(math.ceil(math.sqrt(n)))
lst = []
for i in range(1, n1):
for j in range(i + 1, n1):
if i**2 + j**2 >= n:
break
tz = int(math.sqrt(n - i**2 - j**2))
if tz ** 2 == n - i**2 - j**2:
a = [i, j, tz]
lst.append(a)
return lst
print find(12345678)
已测试:here
对于唯一值,您可以使用不同的数据结构,例如维护唯一值的集合,或者放置另一个 if-check 以查看是否已经存在类似的条目。
lst = set()
for i in range(1, n1):
for j in range(i + 1, n1):
if i**2 + j**2 >= n:
break
tz = int(math.sqrt(n - i**2 - j**2))
if tz ** 2 == n - i**2 - j**2:
#This if-block makes the elements in tuple, sorted.
#so that set can compare two values for equality
#because (1, 2, 3) != (2, 3, 1)
if tz < i:
i, j, tz = tz, i, j
elif tz < j:
j, tz = tz, j
a = (i, j, tz)
lst.add(a)
已测试:here
或
如果您想继续该列表,您可以添加另一个检查并仅在当前值集不在列表中时附加到列表。
if tz < i:
i, j, tz = tz, i, j
elif tz < j:
j, tz = tz, j
a = [i, j, tz]
if not(a in lst):
lst.append(a)
两者都是以增加时间复杂度为代价的。但是,使用 set
的 IMO 非常适合,因为它的平均情况复杂度是 O(1)
,而在 list
中搜索元素的存在是 O(n)
.
以下是您可以在@vishal-wadhwa 的基础上进行的一些小优化。
您可以检查任何平方除以 8 得到的余数是 0 或 4 或 1。 例如,使用它您可以立即判断如果 n 除以 8 有余数 7,则无解。
如果 n 是 4 的倍数,则只有偶数个解,所以你可以除以 4 来求解。
最后,如果 n 除以 4 得到余数 1、2 或 3,则解中必须有 1、2 或 3 个奇数,以及 2、1 或 0 个偶数。
下面的代码使用所有这些来偷工减料。
import math
def new_find(n):
if n == 0:
return [(0, 0, 0)]
f = 1
while n % 4 == 0:
n //= 4
f *= 2
if f > 1:
return [(f*a, f*b, f*c) for (a, b, c) in find(n)]
if n % 8 == 7:
return []
offs = 0 if n % 4 == 1 else 1
split = 3 if n % 4 == 3 else 2
sol = []
for i in range(offs, n, 2):
if i*i > n // split:
break
for j in range(i, n, 2):
if j*j > (n - i*i) / (split-1):
break
rem = n - i*i - j*j
rs = int(math.sqrt(rem))
if rem == rs*rs:
sol.append((i, j, rs))
return sol
def orig_find(n):
lst = set()
n1 = int(math.ceil(math.sqrt(n)))
for i in range(0, n1):
for j in range(i, n1):
if i**2 + j**2 >= n:
break
tz = int(math.sqrt(n - i**2 - j**2))
if tz ** 2 == n - i**2 - j**2:
#This if-block makes the elements in tuple, sorted.
#so that set can compare two values for equality
#because (1, 2, 3) != (2, 3, 1)
if tz < i:
i, j, tz = tz, i, j
elif tz < j:
j, tz = tz, j
a = (i, j, tz)
lst.add(a)
return lst
def check_equal(a, b):
import operator
return all(map(operator.eq, sorted(map(sorted, a)), sorted(map(sorted, b))))
print(check_equal(new_find(12345678), orig_find(12345678)))
# True