How to properly implement a disjoint-set data structure for finding spanning forests in Python?
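Recently I have been working through the Google Kick Start 2019 programming problems, and I tried to implement Cherries Mesh from Round E by following the analysis.
Here is the link to the problem and its analysis:
https://codingcompetitions.withgoogle.com/kickstart/round/0000000000050edb/0000000000170721
Here is the code I implemented: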
t = int(input())
for k in range(1,t+1):
    n, q = map(int,input().split())
    se = list()
    for _ in range(q):
        a,b = map(int,input().split())
        se.append((a,b))
    l = [{x} for x in range(1,n+1)]
    #print(se)
    for s in se:
        i = 0
        while ({s[0]}.isdisjoint(l[i])):
            i += 1
        j = 0
        while ({s[1]}.isdisjoint(l[j])):
            j += 1
        if i!=j:
            l[i].update(l[j])
            l.pop(j)
    #print(l)
    count = q+2*(len(l)-1)
    print('Case #',k,': ',count,sep='')
This passes the sample cases but fails the test set. As far as I can tell, it should be correct. Am I doing something wrong?
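Two issues:
- Your algorithm for checking whether an edge links two disjoint sets, and merging them if it does, is inefficient. The Union-Find algorithm on a Disjoint-Set data structure is more efficient.
- The final count does not depend on the original number of black edges, because those black edges may form cycles, so some of them should not be counted. Instead, count how many edges the solution has in total (irrespective of colour). As the solution represents a minimum spanning tree, the number of edges is n-1. Add to that the number of red edges, which is the number of disjoint sets minus one (the len(l)-1 you already compute), since each red edge costs one unit more than a black one.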
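To make the second point concrete, here is a small sketch comparing the question's formula with the corrected one. The function names original_count and corrected_count are hypothetical (they do not appear in the question), and the sketch assumes black edges cost 1 and red edges cost 2. A triangle of three black edges contains a cycle, so only two of them can be part of the spanning tree:

```python
def original_count(black_edge_count, cluster_count):
    # Formula from the question: counts every black edge, even redundant ones.
    return black_edge_count + 2 * (cluster_count - 1)


def corrected_count(node_count, cluster_count):
    # A spanning tree has node_count - 1 edges; cluster_count - 1 of them
    # must be red (cost 2), and the rest can be black (cost 1).
    red_edges = cluster_count - 1
    black_edges = (node_count - 1) - red_edges
    return black_edges * 1 + red_edges * 2


# Triangle: 3 nodes, black edges (1,2), (2,3), (1,3), so a single cluster.
print(original_count(3, 1))   # 3, one redundant black edge is counted
print(corrected_count(3, 1))  # 2
```

The corrected expression simplifies to node_count + cluster_count - 2.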
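I would also advise using meaningful variable names, which make the code much easier to understand. Single-letter names such as t, q or s are not very helpful.
There are several ways to implement the Union-Find functions. Here I have defined a Node class that has those methods: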
# Implementation of Union-Find (Disjoint Set)
class Node:
    def __init__(self):
        self.parent = self
        self.rank = 0

    def find(self):
        if self.parent.parent != self.parent:
            self.parent = self.parent.find()
        return self.parent

    def union(self, other):
        node = self.find()
        other = other.find()
        if node == other:
            return True # was already in same set
        if node.rank > other.rank:
            node, other = other, node
        node.parent = other
        other.rank = max(other.rank, node.rank + 1)
        return False # was not in same set, but now is

testcount = int(input())
for testid in range(1, testcount + 1):
    nodecount, blackcount = map(int, input().split())
    # use Union-Find data structure
    nodes = [Node() for _ in range(nodecount)]
    blackedges = []
    for _ in range(blackcount):
        start, end = map(int, input().split())
        blackedges.append((nodes[start - 1], nodes[end - 1]))
    # Start with assumption that all edges on MST are red:
    sugarcount = nodecount * 2 - 2
    for start, end in blackedges:
        if not start.union(end): # When edge connects two disjoint sets:
            sugarcount -= 1 # Use this black edge instead of red one
    print('Case #{}: {}'.format(testid, sugarcount))
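You are getting an incorrect answer because you calculate the count incorrectly. It takes n-1 edges to connect n nodes into a tree, and num_clusters-1 of those have to be red.
But even if you fix that, your program will still be very slow, because of your disjoint set implementation.
Thankfully, it is pretty easy to implement a very efficient disjoint set data structure in a single array/list/vector in just about any programming language. Here is a nice one in Python. I have Python 2 on my box, so my print and input statements are a little different from yours: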
# Create a disjoint set data structure, with n singletons, numbered 0 to n-1
# This is a simple array where for each item x:
# x > 0 => a set of size x, and x <= 0 => a link to -x
def ds_create(n):
    return [1]*n

# Find the current root set for original singleton index
def ds_find(ds, index):
    val = ds[index]
    if (val > 0):
        return index
    root = ds_find(ds, -val)
    if (val != -root):
        ds[index] = -root # path compression
    return root

# Merge given sets. returns False if they were already merged
def ds_union(ds, a, b):
    aroot = ds_find(ds, a)
    broot = ds_find(ds, b)
    if aroot == broot:
        return False
    # union by size
    if ds[aroot] >= ds[broot]:
        ds[aroot] += ds[broot]
        ds[broot] = -aroot
    else:
        ds[broot] += ds[aroot]
        ds[aroot] = -broot
    return True

# Count root sets
def ds_countRoots(ds):
    return sum(1 for v in ds if v > 0)

#
# CherriesMesh solution
#
numTests = int(raw_input())
for testNum in range(1,numTests+1):
    numNodes, numEdges = map(int,raw_input().split())
    sets = ds_create(numNodes)
    for _ in range(numEdges):
        a,b = map(int,raw_input().split())
        ds_union(sets, a-1, b-1)
    # n-1 tree edges, of which (roots - 1) are red and cost one extra each
    count = numNodes + ds_countRoots(sets) - 2
    print 'Case #{0}: {1}'.format(testNum, count)
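For readers on Python 3, the same array-based idea can be sketched as follows. This is a port under my own assumptions rather than the code above: the names find, union and cherries_mesh are hypothetical, find is written iteratively so that it does not depend on the recursion limit, and input parsing is left out, so the function takes the node count and a list of 1-based edge pairs directly.

```python
def find(ds, index):
    # Positive entries are roots holding the set size;
    # non-positive entries are links to index -value.
    root = index
    while ds[root] <= 0:
        root = -ds[root]
    # Path compression: point every node on the path directly at the root.
    while ds[index] <= 0:
        parent = -ds[index]
        ds[index] = -root
        index = parent
    return root


def union(ds, a, b):
    # Merge the sets containing a and b; return False if already merged.
    a_root, b_root = find(ds, a), find(ds, b)
    if a_root == b_root:
        return False
    # Union by size: attach the smaller set under the larger one.
    if ds[a_root] < ds[b_root]:
        a_root, b_root = b_root, a_root
    ds[a_root] += ds[b_root]
    ds[b_root] = -a_root
    return True


def cherries_mesh(node_count, black_edges):
    ds = [1] * node_count  # node_count singleton sets
    for a, b in black_edges:
        union(ds, a - 1, b - 1)  # edges are 1-based
    roots = sum(1 for value in ds if value > 0)
    return node_count + roots - 2


print(cherries_mesh(3, [(2, 3)]))                  # 3
print(cherries_mesh(3, [(1, 2), (2, 3), (1, 3)]))  # 2
```

Wrapping this in the same input loop as above only needs input() in place of raw_input() and print(...) as a function.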