带减法的子集求和算法
Algorithm for subset-sum with subtraction
我有一个子集和问题,您可以在其中添加或减去项。例如,如果我有五个项(1、2、3、4、5),我想知道有多少种方法可以 add/subtract 使这些项成为 7:
- 3 + 4
- 2 + 5
- 1 + 2 + 4
- 5 - 2 + 4
- 等等
我在Python写了一些代码,但是一旦有很多项就很慢:
import itertools
from collections import OrderedDict
sum_answer = 1
terms = {"T1": 1, "T2": -2, "T3": 3, "T4": -4, "T5": 5}
numlist = [v for v in terms.values()]
zerlist = [x for x in itertools.repeat(0, len(numlist))]
opslist = [item for item in itertools.product((1, -1), repeat=len(numlist))]
res_list = []
for i in range(1, len(numlist)):
combos = itertools.combinations(numlist, i)
for x in combos:
prnlist = list(x) + zerlist[:len(numlist) - len(x)]
for o in opslist:
operators = list(o)
result = []
res_sum = 0
for t in range(len(prnlist)):
if operators[t] == 1:
ops = "+"
else:
ops = "-"
if prnlist[t] != 0:
result += [ops, list(terms.keys())[list(terms.values()).index(prnlist[t])]]
res_sum += operators[t] * prnlist[t]
if sum_answer == res_sum:
res_list += [" ".join(result)]
for ans in OrderedDict.fromkeys(res_list).keys():
print(ans)
我意识到一百万个嵌套循环的效率非常低,那么有什么地方可以用更好的算法来加速吗?
类似于"regular"子集和问题——你用DP解决问题的地方,你也会在这里使用它,但需要多一种可能性——减少当前元素添加它。
f(0,i) = 1 //successive subset
f(x,0) = 0 x>0 //failure subset
f(x,i) = f(x+element[i],i-1) + f(x-element[i],i-1) + f(x,i-1)
^^^
This is the added option for substraction
将其转换为自下而上的 DP 解决方案时,您需要创建一个大小为 (SUM+1) * (2n+1)
的矩阵,其中 SUM
是所有元素的总和,n
是元素数量。
我认为您的想法基本上是正确的:生成每个术语的组合,然后求和,看看是否命中。不过,您可以优化代码。
问题是,一旦生成 1 + 2
,您会发现它与您想要的总和不匹配,因此将其丢弃。但是,如果您向其中添加 4
,它就是一个解决方案。但是,在生成 1 + 2 + 4
之前,您不会得到该解决方案,届时您将从头开始计算总和。您还可以为每个组合从头开始添加运算符,出于同样的原因,这也会做很多冗余工作。
您还使用了很多列表操作,这可能很慢。
我会这样做:
def solve(terms_list, stack, current_s, desired_s):
if len(terms_list) == 0:
if current_s == desired_s:
print(stack)
return
for w in [0, 1, -1]: # ignore term (0), add it (1), subtract it (-1)
stack.append(w)
solve(terms_list[1:], stack, current_s + w * terms_list[0], desired_s)
stack.pop()
初始调用例如solve([1,2,3,4,5], [], 0, 7)
.
请注意,这具有复杂性 O(3^n)
(有点,请继续阅读),因为每个术语都可以添加、减去或忽略。
我实际实现的复杂度是 O(n*3^n)
,因为递归调用复制了 terms_list
参数。然而,您可以避免这种情况,但我想让代码更简单,并将其留作练习。您也可以避免在打印之前构建实际表达式,而是逐步构建它,但您可能需要更多参数。
但是,O(3^n)
仍然很多,无论您做什么,您都不应该指望它对大型 n
表现很好。
现在您正在尝试暴力破解一行中所有可能的字段值组合(然后针对其他行对每个组合进行有效性测试)。
我想您有很多行数据要处理;我建议您通过获取一堆行(至少与您要解决的字段一样多)并应用近似矩阵求解器来利用它,例如 numpy.linalg.lstsq
.
这有很多重要的优点:
允许您理智地处理舍入错误问题(如果您的任何字段是非整数则必需)
允许您轻松处理系数不在{-1, 0, 1}
中的字段,即系数可能类似于0.12
[=46=的税率]
使用完全支持的代码,您无需调试或维护
使用高度优化的代码 运行 相当快 (** 最有可能,取决于你的 numpy 编译时使用的选项)
具有更好的时间复杂度(类似于 O(n ** 2.8) 而不是 O(3 ** n)),这意味着它应该扩展到更多的字段
所以,一些测试数据:
import numpy as np
# generate test data
def make_test_data(coeffs, mean=20.0, base=0.05):
w = len(coeffs) # number of fields
h = int(1.5 * w) # number of rows of data
rows = np.random.exponential(mean - base, (h, w)) + base
totals = data.dot(coeffs)
return rows.round(2), totals.round(2)
这给了我们类似的东西
>>> rows, totals = make_test_data([0, 1, 1, 0, -1, 0.12])
>>> print(rows)
[[ 1.45 17.63 22.54 5.54 37.06 1.47]
[ 11.71 80.43 26.43 18.48 11.08 8.8 ]
[ 16.09 11.34 63.74 3.31 13.2 13.35]
[ 11.96 12.17 10.23 8.15 73.3 0.42]
[ 4.03 8.01 20.84 21.46 2.76 18.98]
[ 3.24 6.6 35.06 23.17 9.03 8.58]
[ 25.05 33.72 6.82 0.49 46.76 12.21]
[ 70.27 1.48 23.05 0.69 31.11 43.13]
[ 9.04 10.45 15.08 4.32 52.94 11.13]]
>>> print(totals)
[ 3.29 96.84 63.48 -50.85 28.37 33.66 -4.75 -1.4 -26.07]
和求解器代码,
>>> sol = np.linalg.lstsq(rows, totals) # one line!
>>> print(sol[0]) # note the solutions are not *exact*
[ -1.485730e-04 1.000072e+00 9.999334e-01 -7.992023e-05 -9.999552e-01 1.203379e-01]
>>> print(sol[0].round(3)) # but they are *very* close
[ 0. 1. 1. 0. -1. 0.12]
我有一个子集和问题,您可以在其中添加或减去项。例如,如果我有五个项(1、2、3、4、5),我想知道有多少种方法可以 add/subtract 使这些项成为 7:
- 3 + 4
- 2 + 5
- 1 + 2 + 4
- 5 - 2 + 4
- 等等
我在Python写了一些代码,但是一旦有很多项就很慢:
import itertools
from collections import OrderedDict
sum_answer = 1
terms = {"T1": 1, "T2": -2, "T3": 3, "T4": -4, "T5": 5}
numlist = [v for v in terms.values()]
zerlist = [x for x in itertools.repeat(0, len(numlist))]
opslist = [item for item in itertools.product((1, -1), repeat=len(numlist))]
res_list = []
for i in range(1, len(numlist)):
combos = itertools.combinations(numlist, i)
for x in combos:
prnlist = list(x) + zerlist[:len(numlist) - len(x)]
for o in opslist:
operators = list(o)
result = []
res_sum = 0
for t in range(len(prnlist)):
if operators[t] == 1:
ops = "+"
else:
ops = "-"
if prnlist[t] != 0:
result += [ops, list(terms.keys())[list(terms.values()).index(prnlist[t])]]
res_sum += operators[t] * prnlist[t]
if sum_answer == res_sum:
res_list += [" ".join(result)]
for ans in OrderedDict.fromkeys(res_list).keys():
print(ans)
我意识到一百万个嵌套循环的效率非常低,那么有什么地方可以用更好的算法来加速吗?
类似于"regular"子集和问题——你用DP解决问题的地方,你也会在这里使用它,但需要多一种可能性——减少当前元素添加它。
f(0,i) = 1 //successive subset
f(x,0) = 0 x>0 //failure subset
f(x,i) = f(x+element[i],i-1) + f(x-element[i],i-1) + f(x,i-1)
^^^
This is the added option for substraction
将其转换为自下而上的 DP 解决方案时,您需要创建一个大小为 (SUM+1) * (2n+1)
的矩阵,其中 SUM
是所有元素的总和,n
是元素数量。
我认为您的想法基本上是正确的:生成每个术语的组合,然后求和,看看是否命中。不过,您可以优化代码。
问题是,一旦生成 1 + 2
,您会发现它与您想要的总和不匹配,因此将其丢弃。但是,如果您向其中添加 4
,它就是一个解决方案。但是,在生成 1 + 2 + 4
之前,您不会得到该解决方案,届时您将从头开始计算总和。您还可以为每个组合从头开始添加运算符,出于同样的原因,这也会做很多冗余工作。
您还使用了很多列表操作,这可能很慢。
我会这样做:
def solve(terms_list, stack, current_s, desired_s):
if len(terms_list) == 0:
if current_s == desired_s:
print(stack)
return
for w in [0, 1, -1]: # ignore term (0), add it (1), subtract it (-1)
stack.append(w)
solve(terms_list[1:], stack, current_s + w * terms_list[0], desired_s)
stack.pop()
初始调用例如solve([1,2,3,4,5], [], 0, 7)
.
请注意,这具有复杂性 O(3^n)
(有点,请继续阅读),因为每个术语都可以添加、减去或忽略。
我实际实现的复杂度是 O(n*3^n)
,因为递归调用复制了 terms_list
参数。然而,您可以避免这种情况,但我想让代码更简单,并将其留作练习。您也可以避免在打印之前构建实际表达式,而是逐步构建它,但您可能需要更多参数。
但是,O(3^n)
仍然很多,无论您做什么,您都不应该指望它对大型 n
表现很好。
现在您正在尝试暴力破解一行中所有可能的字段值组合(然后针对其他行对每个组合进行有效性测试)。
我想您有很多行数据要处理;我建议您通过获取一堆行(至少与您要解决的字段一样多)并应用近似矩阵求解器来利用它,例如 numpy.linalg.lstsq
.
这有很多重要的优点:
允许您理智地处理舍入错误问题(如果您的任何字段是非整数则必需)
允许您轻松处理系数不在
[=46=的税率]{-1, 0, 1}
中的字段,即系数可能类似于0.12
使用完全支持的代码,您无需调试或维护
使用高度优化的代码 运行 相当快 (** 最有可能,取决于你的 numpy 编译时使用的选项)
具有更好的时间复杂度(类似于 O(n ** 2.8) 而不是 O(3 ** n)),这意味着它应该扩展到更多的字段
所以,一些测试数据:
import numpy as np
# generate test data
def make_test_data(coeffs, mean=20.0, base=0.05):
w = len(coeffs) # number of fields
h = int(1.5 * w) # number of rows of data
rows = np.random.exponential(mean - base, (h, w)) + base
totals = data.dot(coeffs)
return rows.round(2), totals.round(2)
这给了我们类似的东西
>>> rows, totals = make_test_data([0, 1, 1, 0, -1, 0.12])
>>> print(rows)
[[ 1.45 17.63 22.54 5.54 37.06 1.47]
[ 11.71 80.43 26.43 18.48 11.08 8.8 ]
[ 16.09 11.34 63.74 3.31 13.2 13.35]
[ 11.96 12.17 10.23 8.15 73.3 0.42]
[ 4.03 8.01 20.84 21.46 2.76 18.98]
[ 3.24 6.6 35.06 23.17 9.03 8.58]
[ 25.05 33.72 6.82 0.49 46.76 12.21]
[ 70.27 1.48 23.05 0.69 31.11 43.13]
[ 9.04 10.45 15.08 4.32 52.94 11.13]]
>>> print(totals)
[ 3.29 96.84 63.48 -50.85 28.37 33.66 -4.75 -1.4 -26.07]
和求解器代码,
>>> sol = np.linalg.lstsq(rows, totals) # one line!
>>> print(sol[0]) # note the solutions are not *exact*
[ -1.485730e-04 1.000072e+00 9.999334e-01 -7.992023e-05 -9.999552e-01 1.203379e-01]
>>> print(sol[0].round(3)) # but they are *very* close
[ 0. 1. 1. 0. -1. 0.12]