Python 中的线段树实现
Segment tree implementation in Python
我正在解决 this problem using segment tree 但我遇到时间限制错误。
下面是我的范围最小查询的原始代码,通过在我的代码中将 min
更改为 max
可以解决上述问题。我不知道如何提高代码的性能。你能帮我解决它的性能问题吗?
t = [None] * 2 * 7 # n is length of list
def build(a, v, start, end):
'''
A recursive function that constructs Segment Tree for list a.
v is the starting node
start and end are the index of array
'''
n = len(a)
if start == end:
t[v] = a[start]
else:
mid = (start + end) / 2
build(a, v * 2, start, mid) # v*2 is left child of parent v
# v*2+1 is the right child of parent v
build(a, v * 2 + 1, mid + 1, end)
t[v] = min(t[2 * v], t[2 * v + 1])
return t
print build([18, 17, 13, 19, 15, 11, 20], 1, 0, 6)
inf = 10**9 + 7
def range_minimum_query(node, segx, segy, qx, qy):
'''
returns the minimum number in range(qx,qy)
segx and segy represent the segment index
'''
if qx > segy or qy < segx: # query out of range
return inf
elif segx >= qx and segy <= qy: # query range inside segment range
return t[node]
else:
return min(range_minimum_query(node * 2, segx, (segx + segy) / 2, qx, qy), range_minimum_query(node * 2 + 1, ((segx + segy) / 2) + 1, segy, qx, qy))
print range_minimum_query(1, 1, 7, 1, 3)
# returns 13
这可以迭代实现吗?
您可以尝试使用生成器,因为您可以绕过很多限制。但是您没有提供清楚显示您的性能问题的数据集 - 您能提供有问题的数据集吗?
这里你可以试试:
t=[None]*2*7
inf=10**9+7
def build_generator(a, v, start, end):
n = len(a)
if start == end:
t[v] = a[start]
else:
mid = (start + end) / 2
next(build_generator(a, v * 2, start, mid))
next(build_generator(a, v * 2 + 1, mid + 1, end))
t[v] = min(t[2 * v], t[2 * v + 1])
yield t
def range_minimum_query_generator(node,segx,segy,qx,qy):
if qx > segy or qy < segx:
yield inf
elif segx >= qx and segy <= qy:
yield t[node]
else:
min_ = min(
next(range_minimum_query_generator(node*2,segx,(segx+segy)/2,qx,qy)),
next(range_minimum_query_generator(node*2+1,((segx+segy)/2)+1,segy,qx,qy))
)
yield min_
next(build_generator([18,17,13,19,15,11,20],1,0,6))
value = next(range_minimum_query_generator(1, 1, 7, 1, 3))
print(value)
编辑
事实上,这可能无法解决您的问题。还有另一种方法可以解决任何递归限制(如 D. Beazley 在其生成器教程中所述 - https://www.youtube.com/watch?v=D1twn9kLmYg&t=9588s 时间码 2h00 左右)
语言选择
首先,如果你使用 python,你可能永远不会通过评分。如果您在这里查看所有过去解决方案的状态,http://www.spoj.com/status/GSS1/start=0,您会发现几乎所有被接受的解决方案都是用 C++ 编写的。我认为您别无选择,只能使用 C++。请注意时间限制是 0.115s-0.230s。这是一个 "only-for-C/C++" 时间限制。对于接受其他语言解决方案的问题,时间限制将是一个 "round" 数字,例如 1 秒。 Python 在这种类型的环境中比 C++ 慢 2-4 倍。
线段树实现问题...?
其次,我不确定您的代码是否真的在构建线段树。具体来说,我不明白为什么会有这一行:
t[v]=min(t[2*v],t[2*v+1])
我很确定线段树中的节点存储其子节点的总和,因此如果您的实现接近正确,我认为它应该改为读取
t[v] = t[2*v] + t[2*v+1]
如果您的代码是 "correct",那么如果您甚至不存储区间总和,我会质疑您如何在 [x_i, y_i]
范围内找到最大区间总和。
迭代线段树
第三,线段树可以迭代实现。这是一个 C++ 教程:http://codeforces.com/blog/entry/18051。
线段树的速度应该不够快...
最后,我不明白线段树将如何帮助您解决这个问题。线段树可让您查询 log(n)
范围内的总和。此问题要求任何范围的最大可能总和。我还没有听说过允许 "range minimum query" 或 "range maximum query."
的线段树
对于 1 个查询,一个简单的解决方案是 O(n^3)(尝试所有 n^2 个可能的起点和终点并计算 O(n) 操作中的总和)。而且,如果您使用线段树,您可以获得 O(log(n)) 而不是 O(n) 的总和。这只会使您加速到 O(n^2 log(n)),这不适用于 N = 50000。
替代算法
我认为你应该看看这个,它在 O(n) 中运行每个查询:http://www.geeksforgeeks.org/largest-sum-contiguous-subarray/。将它写在 C/C++ 中并像一位评论者建议的那样高效地使用 IO。
我正在解决 this problem using segment tree 但我遇到时间限制错误。
下面是我的范围最小查询的原始代码,通过在我的代码中将 min
更改为 max
可以解决上述问题。我不知道如何提高代码的性能。你能帮我解决它的性能问题吗?
t = [None] * 2 * 7 # n is length of list
def build(a, v, start, end):
'''
A recursive function that constructs Segment Tree for list a.
v is the starting node
start and end are the index of array
'''
n = len(a)
if start == end:
t[v] = a[start]
else:
mid = (start + end) / 2
build(a, v * 2, start, mid) # v*2 is left child of parent v
# v*2+1 is the right child of parent v
build(a, v * 2 + 1, mid + 1, end)
t[v] = min(t[2 * v], t[2 * v + 1])
return t
print build([18, 17, 13, 19, 15, 11, 20], 1, 0, 6)
inf = 10**9 + 7
def range_minimum_query(node, segx, segy, qx, qy):
'''
returns the minimum number in range(qx,qy)
segx and segy represent the segment index
'''
if qx > segy or qy < segx: # query out of range
return inf
elif segx >= qx and segy <= qy: # query range inside segment range
return t[node]
else:
return min(range_minimum_query(node * 2, segx, (segx + segy) / 2, qx, qy), range_minimum_query(node * 2 + 1, ((segx + segy) / 2) + 1, segy, qx, qy))
print range_minimum_query(1, 1, 7, 1, 3)
# returns 13
这可以迭代实现吗?
您可以尝试使用生成器,因为您可以绕过很多限制。但是您没有提供清楚显示您的性能问题的数据集 - 您能提供有问题的数据集吗?
这里你可以试试:
t=[None]*2*7
inf=10**9+7
def build_generator(a, v, start, end):
n = len(a)
if start == end:
t[v] = a[start]
else:
mid = (start + end) / 2
next(build_generator(a, v * 2, start, mid))
next(build_generator(a, v * 2 + 1, mid + 1, end))
t[v] = min(t[2 * v], t[2 * v + 1])
yield t
def range_minimum_query_generator(node,segx,segy,qx,qy):
if qx > segy or qy < segx:
yield inf
elif segx >= qx and segy <= qy:
yield t[node]
else:
min_ = min(
next(range_minimum_query_generator(node*2,segx,(segx+segy)/2,qx,qy)),
next(range_minimum_query_generator(node*2+1,((segx+segy)/2)+1,segy,qx,qy))
)
yield min_
next(build_generator([18,17,13,19,15,11,20],1,0,6))
value = next(range_minimum_query_generator(1, 1, 7, 1, 3))
print(value)
编辑
事实上,这可能无法解决您的问题。还有另一种方法可以解决任何递归限制(如 D. Beazley 在其生成器教程中所述 - https://www.youtube.com/watch?v=D1twn9kLmYg&t=9588s 时间码 2h00 左右)
语言选择
首先,如果你使用 python,你可能永远不会通过评分。如果您在这里查看所有过去解决方案的状态,http://www.spoj.com/status/GSS1/start=0,您会发现几乎所有被接受的解决方案都是用 C++ 编写的。我认为您别无选择,只能使用 C++。请注意时间限制是 0.115s-0.230s。这是一个 "only-for-C/C++" 时间限制。对于接受其他语言解决方案的问题,时间限制将是一个 "round" 数字,例如 1 秒。 Python 在这种类型的环境中比 C++ 慢 2-4 倍。
线段树实现问题...?
其次,我不确定您的代码是否真的在构建线段树。具体来说,我不明白为什么会有这一行:
t[v]=min(t[2*v],t[2*v+1])
我很确定线段树中的节点存储其子节点的总和,因此如果您的实现接近正确,我认为它应该改为读取
t[v] = t[2*v] + t[2*v+1]
如果您的代码是 "correct",那么如果您甚至不存储区间总和,我会质疑您如何在 [x_i, y_i]
范围内找到最大区间总和。
迭代线段树
第三,线段树可以迭代实现。这是一个 C++ 教程:http://codeforces.com/blog/entry/18051。
线段树的速度应该不够快...
最后,我不明白线段树将如何帮助您解决这个问题。线段树可让您查询 log(n)
范围内的总和。此问题要求任何范围的最大可能总和。我还没有听说过允许 "range minimum query" 或 "range maximum query."
对于 1 个查询,一个简单的解决方案是 O(n^3)(尝试所有 n^2 个可能的起点和终点并计算 O(n) 操作中的总和)。而且,如果您使用线段树,您可以获得 O(log(n)) 而不是 O(n) 的总和。这只会使您加速到 O(n^2 log(n)),这不适用于 N = 50000。
替代算法
我认为你应该看看这个,它在 O(n) 中运行每个查询:http://www.geeksforgeeks.org/largest-sum-contiguous-subarray/。将它写在 C/C++ 中并像一位评论者建议的那样高效地使用 IO。