Tarjan 算法的非递归版本

Non-recursive version of Tarjan's algorithm

我有以下 Tarjan 算法的(递归)实现来查找图中的强连通分量,它工作正常:

public class StronglyConnectedComponents
{
    public static List<List<int>> Search(Graph graph)
    {
        StronglyConnectedComponents scc = new StronglyConnectedComponents();
        return scc.Tarjan(graph);
    }

    private int preCount;
    private int[] low;
    private bool[] visited;
    private Graph graph;
    private List<List<int>> stronglyConnectedComponents = new List<List<int>>();
    private Stack<int> stack = new Stack<int>();

    public List<List<int>> Tarjan(Graph graph)
    {
        this.graph = graph;
        low = new int[graph.VertexCount];
        visited = new bool[graph.VertexCount];

        for (int v = 0; v < graph.VertexCount; v++) if (!visited[v]) DFS(v);

        return stronglyConnectedComponents;
    }

    public void DFS(int v)
    {
        low[v] = preCount++;
        visited[v] = true;
        stack.Push(v);
        int min = low[v];
        int edgeCount = graph.OutgoingEdgeCount(v);
        for (int i = 0; i < edgeCount; i++)
        {
            var edge = graph.OutgoingEdge(v, i);
            int target = edge.Target;

            if (!visited[target]) DFS(target);
            if (low[target] < min) min = low[target];
        }

        if (min < low[v])
        {
            low[v] = min;
            return;
        }

        List<int> component = new List<int>();

        int w;
        do
        {
            w = stack.Pop();
            component.Add(w);
            low[w] = graph.VertexCount;
        } while (w != v);
        stronglyConnectedComponents.Add(component);
    }
}

但是在大图上,递归版本显然会抛出 WhosebugException。因此我想让算法成为非递归的。

我尝试用以下(非递归)函数替换函数 DFS,但该算法不再起作用。有人可以帮忙吗?

private void DFS2(int vertex)
{
    bool[] visited = new bool[graph.VertexCount];
    Stack<int> stack = new Stack<int>();
    stack.Push(vertex);
    int min = low[vertex];

    while (stack.Count > 0)
    {
        int v = stack.Pop();
        if (visited[v]) continue;
        visited[v] = true;

        int edgeCount = graph.OutgoingEdgeCount(v);
        for (int i = 0; i < edgeCount; i++)
        {
            int target = graph.OutgoingEdge(v, i).Target;
            stack.Push(target);
            if (low[target] < min) min = low[target];
        }
    }

    if (min < low[vertex])
    {
        low[vertex] = min;
        return;
    }

    List<int> component = new List<int>();

    int w;
    do
    {
        w = stack.Pop();
        component.Add(w);
        low[w] = graph.VertexCount;
    } while (w != vertex);
    stronglyConnectedComponents.Add(component);
}

以下代码显示测试:

public void CanFindStronglyConnectedComponents()
{
    Graph graph = new Graph(8);
    graph.AddEdge(0, 1);
    graph.AddEdge(1, 2);
    graph.AddEdge(2, 3);
    graph.AddEdge(3, 2);
    graph.AddEdge(3, 7);
    graph.AddEdge(7, 3);
    graph.AddEdge(2, 6);
    graph.AddEdge(7, 6);
    graph.AddEdge(5, 6);
    graph.AddEdge(6, 5);
    graph.AddEdge(1, 5);
    graph.AddEdge(4, 5);
    graph.AddEdge(4, 0);
    graph.AddEdge(1, 4);

    var scc = StronglyConnectedComponents.Search(graph);
    Assert.AreEqual(3, scc.Count);
    Assert.IsTrue(SetsEqual(Set(5, 6), scc[0]));
    Assert.IsTrue(SetsEqual(Set(7, 3, 2), scc[1]));
    Assert.IsTrue(SetsEqual(Set(4, 1, 0), scc[2]));
}

private IEnumerable<int> Set(params int[] set) => set;

private bool SetsEqual(IEnumerable<int> set1, IEnumerable<int> set2)
{
    if (set1.Count() != set2.Count()) return false;
    return set1.Intersect(set2).Count() == set1.Count();
}

这里直接非递归翻译一下原来的递归实现(假设是正确的):

public static List<List<int>> Search(Graph graph)
{
    var stronglyConnectedComponents = new List<List<int>>();

    int preCount = 0;
    var low = new int[graph.VertexCount];
    var visited = new bool[graph.VertexCount];
    var stack = new Stack<int>();

    var minStack = new Stack<int>();
    var enumeratorStack = new Stack<IEnumerator<int>>();
    var enumerator = Enumerable.Range(0, graph.VertexCount).GetEnumerator();
    while (true)
    {
        if (enumerator.MoveNext())
        {
            int v = enumerator.Current;
            if (!visited[v])
            {
                low[v] = preCount++;
                visited[v] = true;
                stack.Push(v);
                int min = low[v];
                // Level down
                minStack.Push(min);
                enumeratorStack.Push(enumerator);
                enumerator = Enumerable.Range(0, graph.OutgoingEdgeCount(v))
                    .Select(i => graph.OutgoingEdge(v, i).Target)
                    .GetEnumerator();
            }
            else if (minStack.Count > 0)
            {
                int min = minStack.Pop();
                if (low[v] < min) min = low[v];
                minStack.Push(min);
            }
        }
        else
        {
            // Level up
            if (enumeratorStack.Count == 0) break;

            enumerator = enumeratorStack.Pop();
            int v = enumerator.Current;
            int min = minStack.Pop();

            if (min < low[v])
            {
                low[v] = min;
            }
            else
            {
                List<int> component = new List<int>();

                int w;
                do
                {
                    w = stack.Pop();
                    component.Add(w);
                    low[w] = graph.VertexCount;
                } while (w != v);
                stronglyConnectedComponents.Add(component);
            }

            if (minStack.Count > 0)
            {
                min = minStack.Pop();
                if (low[v] < min) min = low[v];
                minStack.Push(min);
            }
        }
    }
    return stronglyConnectedComponents;
}

像往常一样,对于这种直接翻译,您需要一个显式堆栈来存储递归调用 "returning" 之后需要恢复的状态。在这种情况下,它是级别顶点枚举器和 min 变量。

请注意,现有的 stack 变量不能使用,因为当处理顶点被推到那里时,它并不总是在退出时弹出(递归实现中的 return 行),这是一个该算法的具体要求。

以下是我必须为 Codeforces 427C 实施的 Python 版本,因为它们不支持增加 Python 堆栈大小。

该代码使用一个额外的调用堆栈,其中包含指向当前节点以及要访问的下一个子节点的指针。

实际算法紧跟pseudocode on Wikipedia

N = # number of vertices
es = # list of edges, [(0,1), (2,4), ...]

class Node:
    def __init__(self, name):
        self.name = name
        self.index = None
        self.lowlink = None
        self.adj = []
        self.on_stack = False

vs = [Node(i) for i in range(N)]
for v, w in es:
    vs[v].adj.append(vs[w])

i = 0
stack = []
call_stack = []
comps = []
for v in vs:
    if v.index is None:
        call_stack.append((v,0))
        while call_stack:
            v, pi = call_stack.pop()
            # If this is first time we see v
            if pi == 0:
                v.index = i
                v.lowlink = i
                i += 1
                stack.append(v)
                v.on_stack = True
            # If we just recursed on something
            if pi > 0:
                prev = v.adj[pi-1]
                v.lowlink = min(v.lowlink, prev.lowlink)
            # Find the next thing to recurse on
            while pi < len(v.adj) and v.adj[pi].index is not None:
                w = v.adj[pi]
                if w.on_stack:
                    v.lowlink = min(v.lowlink, w.index)
                pi += 1
            # If we found something with index=None, recurse
            if pi < len(v.adj):
                w = v.adj[pi]
                call_stack.append((v,pi+1))
                call_stack.append((w,0))
                continue
            # If v is the root of a connected component
            if v.lowlink == v.index:
                comp = []
                while True:
                    w = stack.pop()
                    w.on_stack = False
                    comp.append(w.name)
                    if w is v:
                        break
                comps.append(comp)

或者,在 "simplified" Tarjan's algorithm 之后,我们可以进行以下翻译:

def scc2(graph):
    result = []
    stack = []
    low = {}
    call_stack = []
    for v in graph:
        call_stack.append((v, 0, len(low)))
        while call_stack:
            v, pi, num = call_stack.pop()
            if pi == 0:
                if v in low: continue
                low[v] = num
                stack.append(v)
            if pi > 0:
                low[v] = min(low[v], low[graph[v][pi-1]])
            if pi < len(graph[v]):
                call_stack.append((v, pi+1, num))
                call_stack.append((graph[v][pi], 0, len(low)))
                continue
            if num == low[v]:
                comp = []
                while True:
                    comp.append(stack.pop())
                    low[comp[-1]] = len(graph)
                    if comp[-1] == v: break
                result.append(comp)
    return result