从编译时依赖图 (DAG) 构建异步“未来”回调链

Building asynchronous `future` callback chain from compile-time dependency graph (DAG)

我有一个编译时 directed acyclic graph 的异步任务。 DAG 显示了任务之间的依赖关系:通过分析它,可以了解哪些任务可以 运行 并行 (在单独的线程中) 以及哪些任务需要等待其他任务在开始之前需要完成的任务 (dependencies).

我想使用 boost::future.then(...)when_all(...) 延续辅助函数从 DAG 生成回调链。这一代的结果将是一个函数,调用时将启动回调链并执行 DAG 描述的任务,运行并行执行尽可能多的任务。

但是,我在寻找适用于所有情况的通用算法时遇到了麻烦。

为了让问题更容易理解,我画了几张图。这是一个图例,可以告诉你图中符号的含义:

让我们从一个简单的线性 DAG 开始:

此依赖关系图由三个任务组成 ABCC 取决于 BB 取决于 A。这里没有并行的可能性——生成算法将构建类似于这样的东西:

boost::future<void> A, B, C, end;

A.then([]
    {
        B.then([]
            {
                C.get();
                end.get();
            });
    });

(请注意,所有代码示例并非 100% 有效 - 我忽略了移动语义、转发和 lambda 捕获。)

有很多方法可以解决这个线性 DAG:无论是从末尾开始还是从头开始,构建正确的回调链都是微不足道的。

引入 forks and joins 后,事情开始变得更加复杂。

这是一个带有 fork/join 的 DAG:

很难想到匹配这个DAG的回调链。如果我尝试向后工作,从最后开始,我的推理如下:

一个可能的链看起来像这样:

boost::future<void> A, B, C, D, end;

A.then([]
    {
        boost::when_all(B, C.then([]
                               {
                                   D.get();
                               }))
            .then([]
                {
                    end.get();
                });
    });

手写这条链我觉得很难,而且我也怀疑它的正确性。我想不出一种通用的方法来实现可以生成此算法的算法 - 由于 when_all 需要将其参数移入其中,因此还存在其他困难。

我们再看一个更复杂的例子:

这里我们要尽可能地利用并行性。考虑任务 EE 可以 运行 与任何 [B, C, D].

并行

这是一个可能的回调链:

boost::future<void> A, B, C, D, E, F, end;

A.then([]
    {
        boost::when_all(boost::when_all(B, C).then([]
                            {
                                D.get();
                            }),
            E)
            .then([]
                {
                    F.then([]
                        {
                            end.get();
                        });
                });
    });

我尝试通过多种方式提出通用算法:

显然 "breadth-first traversal" 方法在这里效果不佳。从我手写的代码示例来看,算法似乎需要知道分叉和连接,并且需要能够正确混合 .then(...)when_all(...) continuations.

这是我的最后一个问题:


编辑 1:

Here's an additional approach我正在尝试探索。

想法是从 DAG 生成一个 ([dependencies...] -> [dependents...]) 地图数据结构,并从该地图生成回调链。

如果len(dependencies...) > 1,那么value是一个join节点。

如果len(dependents...) > 1,那么key是一个fork节点。

map中的所有键值对都可以表示为when_all(keys...).then(values...)个continuations

困难的部分是找出 "expand" (想想类似于解析器的东西) 节点的正确顺序以及如何连接 fork/join一起延续。

考虑以下由图像 4.

生成的地图
depenendencies  |  dependents
----------------|-------------
[F]             :  [end]
[D, E]          :  [F]
[B, C]          :  [D]
[A]             :  [E, C, B]
[begin]         :  [A]

通过应用某种类似解析器的 reductions/passes,我们可以获得 "clean" 回调链:

// First pass:
// Convert everything to `when_all(...).then(...)` notation
when_all(F).then(end)
when_all(D, E).then(F)
when_all(B, C).then(D)
when_all(A).then(E, C, B)
when_all(begin).then(A)

// Second pass:
// Solve linear (trivial) transformations
when_all(D, E).then(
    when_all(F).then(end)
)
when_all(B, C).then(D)
when_all(
    when_all(begin).then(A)
).then(E, C, B)

// Third pass:
// Solve fork/join transformations
when_all(
    when_all(begin).then(A)
).then(
    when_all(
        E, 
        when_all(B, C).then(D)
    ).then(
        when_all(F).then(end)
    )   
)

第三遍是最重要的一遍,也是看起来很难设计算法的一遍。

请注意如何在 [E, C, B] 列表中找到 [B, C],以及如何在 [D, E] 依赖项列表中将 D 解释为when_all(B, C).then(D) 并在 when_all(E, when_all(B, C).then(D)).

中与 E 链接在一起

也许整个问题可以简化为:

给定一个包含 [dependencies...] -> [dependents...] 个键值对的映射,如何实现将这些键值对转换为 when_all(...)/.then(...) 延续链的算法?

编辑 2:

这是我为上述方法想出的一些 pseudocode。它似乎适用于我尝试过的 DAG,但我需要花更多时间在它上面 "mentally" 使用其他更棘手的 DAG 配置对其进行测试。

如果您停止以显式依赖项的形式考虑它并组织 DAG,这似乎相当容易。每个任务都可以按如下方式组织(C# 因为解释这个想法要简单得多):

class MyTask
{
    // a list of all tasks that depend on this to be finished
    private readonly ICollection<MyTask> _dependenants;
    // number of not finished dependencies of this task
    private int _nrDependencies;

    public int NrDependencies
    {
        get { return _nrDependencies; }
        private set { _nrDependencies = value; }
    }
}

如果您以这种形式组织了 DAG,那么问题实际上非常简单:可以执行 _nrDependencies == 0 中的每个任务。所以我们需要一个类似于下面的 运行 方法:

public async Task RunTask()
{
    // Execute actual code of the task.
    var tasks = new List<Task>();
    foreach (var dependent in _dependenants)
    {
        if (Interlocked.Decrement(ref dependent._nrDependencies) == 0)
        {
            tasks.Add(Task.Run(() => dependent.RunTask()));
        }
    }
    await Task.WhenAll(tasks);
}

基本上,一旦我们的任务完成,我们就会遍历所有依赖项并执行所有没有未完成依赖项的依赖项。

要开始整个事情,您唯一需要做的就是调用 RunTask() 开始所有从零依赖项开始的任务(至少其中一个必须存在,因为我们有 DAG) .一旦所有这些任务完成,我们就知道整个 DAG 已经执行。

最简单的方法是从图的entry节点开始,就像手写代码一样。为了解决 join 问题,你不能使用递归解决方案,你需要有一个 topological ordering 你的图,然后 然后 构建图到订单。

这可以保证当您构建一个节点时,它的所有前任节点都已经创建。

为了实现这个目标,我们可以使用 reverse postordering 的 DFS。

一旦你进行了拓扑排序,你就可以忘记原来的节点ID,而是用列表中的编号来引用节点。为此,您需要创建一个编译时映射,允许使用拓扑排序中的节点索引而不是节点原始节点索引来检索节点前驱。


编辑: 跟进如何在编译时实现拓扑排序,我重构了这个答案。

为了在同一页面上,我假设您的图表如下所示:

struct mygraph
{
     template<int Id>
     static constexpr auto successors(node_id<Id>) ->
        list< node_id<> ... >; //List of successors for the input node

     template<int Id>
     static constexpr auto predecessors(node_id<Id>) ->
        list< node_id<> ... >; //List of predecessors for the input node

     //Get the task associated with the given node.
     template<int Id>
     static constexpr auto task(node_id<Id>);

     using entry_node = node_id<0>;
};

第一步:拓扑排序

您需要的基本要素是节点 ID 的编译时集。在 TMP 中,集合也是列表,因为在 set<Ids...>Ids 的顺序很重要。这意味着您可以使用相同的数据结构来编码有关节点是否被访问的信息以及同时产生的排序。

/** Topological sort using DFS with reverse-postordering **/
template<class Graph>
struct topological_sort
{
private:
    struct visit;

    // If we reach a node that we already visited, do nothing.
    template<int Id, int ... Is>
    static constexpr auto visit_impl( node_id<Id>,
                                      set<Is...> visited,
                                      std::true_type )
    {
        return visited;
    }

    // This overload kicks in when node has not been visited yet.
    template<int Id, int ... Is>
    static constexpr auto visit_impl( node_id<Id> node,
                                      set<Is...> visited,
                                      std::false_type )
    {
        // Get the list of successors for the current node
        constexpr auto succ = Graph::successors(node);

        // Reverse postordering: we call insert *after* visiting the successors
        // This will call "visit" on each successor, updating the
        // visited set after each step.
        // Then we insert the current node in the set.
        // Notice that if the graph is cyclic we end up in an infinite
        // recursion here.
        return fold( succ,
                     visited,
                     visit() ).insert(node);

        // Conventional DFS would be:
        // return fold( succ, visited.insert(node), visit() );
    }

    struct visit
    {
        // Dispatch to visit_impl depending on the result of visited.contains(node)
        // Note that "contains" returns a type convertible to
        // integral_constant<bool,x>
        template<int Id, int ... Is>
        constexpr auto operator()( set<Is...> visited, node_id<Id> node ) const
        {
            return visit_impl(node, visited, visited.contains(node) );
        }
    };

public:
    template<int StartNodeId>
    static constexpr auto compute( node_id<StartNodeId> node )
    {
        // Start visiting from the entry node
        // The set of visited nodes is initially empty.
        // "as_list" converts set<Is ... > to list< node_id<Is> ... >.
        return reverse( visit()( set<>{}, node ).as_list() );
    }
};

此算法与上一个示例中的图形(假设 A = node_id<0>B = node_id<1> 等)生成 list<A,B,C,D,E,F>.

第 2 步:图表映射

这只是一个适配器,它根据给定的顺序修改图中每个节点的 ID。因此,假设前面的步骤返回 list<C,D,A,B>,此 graph_map 会将索引 0 映射到 C,将索引 1 映射到 D,等等

template<class Graph, class List>
class graph_map
{   
    // Convert a node_id from underlying graph.
    // Use a function-object so that it can be passed to algorithms.
    struct from_underlying
    { 
        template<int I>
        constexpr auto operator()(node_id<I> id) 
        { return node_id< find(id, List{}) >{}; }
    };

    struct to_underlying
    { 
        template<int I>
        constexpr auto operator()(node_id<I> id) 
        { return get<I>(List{}); }
    };

public:        
    template<int Id>
    static constexpr auto successors( node_id<Id> id )
    {
        constexpr auto orig_id = to_underlying()(id);
        constexpr auto orig_succ = Graph::successors( orig_id );
        return transform( orig_succ, from_underlying() );
    }

    template<int Id>
    static constexpr auto predecessors( node_id<Id> id )
    {
        constexpr auto orig_id = to_underlying()(id);
        constexpr auto orig_succ = Graph::predecessors( orig_id );
        return transform( orig_succ, from_underlying() );
    }

    template<int Id>
    static constexpr auto task( node_id<Id> id )
    {
        return Graph::task( to_underlying()(id) );
    }

    using entry_node = decltype( from_underlying()( typename Graph::entry_node{} ) );
};

第 3 步:assemble 结果

我们现在可以按顺序遍历每个节点 ID。由于我们构建图形地图的方式,我们知道 I 的所有前辈都有一个小于 I 的节点 ID,对于每个可能的节点 I.

// Returns a tuple<> of futures
template<class GraphMap, class ... Ts>
auto make_cont( std::tuple< future<Ts> ... > && pred )
{
     // The next node to work with is N:
     constexpr auto current_node = node_id< sizeof ... (Ts) >();

     // Get a list of all the predecessors for the current node.
     auto indices = GraphMap::predecessors( current_node );

     // "select" is some magic function that takes a tuple of Ts
     // and an index_sequence, and returns a tuple of references to the elements 
     // from the input tuple that are in the indices list. 
     auto futures = select( pred, indices );

     // Assuming you have an overload of when_all that takes a tuple,
     // otherwise use C++17 apply.
     auto join = when_all( futures );

     // Note: when_all with an empty parameter list returns a future< tuple<> >,
     // which is always ready.
     // In general this has to be a shared_future, but you can avoid that
     // by checking if this node has only one successor.
     auto next = join.then( GraphMap::task( current_node ) ).share();

     // Return a new tuple of futures, pushing the new future at the back.
     return std::tuple_cat( std::move(pred),
                            std::make_tuple(std::move(next)) );         
}


// Returns a tuple of futures, you can take the last element if you
// know that your DAG has only one leaf, or do some additional 
// processing to extract only the leaf nodes.
template<class Graph>
auto make_callback_chain()
{
    constexpr auto entry_node = typename Graph::entry_node{};

    constexpr auto sorted_list = 
         topological_sort<Graph>::compute( entry_node );

    using map = graph_map< Graph, decltype(sorted_list) >;

    // Note: we are not really using the "index" in the functor here, 
    // we only want to call make_cont once for each node in the graph
    return fold( sorted_list, 
                 std::make_tuple(), //Start with an empty tuple
                 []( auto && tuple, auto index )
                 {
                     return make_cont<map>(std::move(tuple));
                 } );
}

Full live demo

如果可能出现冗余依赖,请先删除它们(参见 https://mathematica.stackexchange.com/questions/33638/remove-redundant-dependencies-from-a-directed-acyclic-graph)。

然后执行以下图形转换(在合并的节点中构建子表达式)直到您下降到单个节点(类似于您计算电阻器网络的方式):

*:额外的传入或传出依赖项,具体取决于位置

(...): 单个节点中的表达式

Java 代码,包括更复杂示例的设置:

public class DirectedGraph {
  /** Set of all nodes in the graph */
  static Set<Node> allNodes = new LinkedHashSet<>();

  static class Node {
    /** Set of all preceeding nodes */
    Set<Node> prev = new LinkedHashSet<>();

    /** Set of all following nodes */
    Set<Node> next = new LinkedHashSet<>();

    String value;

    Node(String value) {
      this.value = value;
      allNodes.add(this);
    }

    void addPrev(Node other) {
      prev.add(other);
      other.next.add(this);
    }

    /** Returns one of the next nodes */
    Node anyNext() {
      return next.iterator().next();
    }

    /** Merges this node with other, then removes other */
    void merge(Node other) {
      prev.addAll(other.prev);
      next.addAll(other.next);
      for (Node on: other.next) {
        on.prev.remove(other);
        on.prev.add(this);
      }
      for (Node op: other.prev) {
        op.next.remove(other);
        op.next.add(this);
      }
      prev.remove(this);
      next.remove(this);
      allNodes.remove(other);
    }

    public String toString() {
      return value;
    }
  }

  /** 
   * Merges sequential or parallel nodes following the given node.
   * Returns true if any node was merged.
   */
  public static boolean processNode(Node node) {
    // Check if we are the start of a sequence. Merge if so.
    if (node.next.size() == 1 && node.anyNext().prev.size() == 1) {
      Node then = node.anyNext();
      node.value += " then " + then.value;
      node.merge(then);
      return true;
    }

    // See if any of the next nodes has a parallel node with
    // the same one level indirect target. 
    for (Node next : node.next) {

      // Nodes must have only one in and out connection to be merged.
      if (next.prev.size() == 1 && next.next.size() == 1) {

        // Collect all parallel nodes with only one in and out connection 
        // and the same target; the same source is implied by iterating over 
        // node.next again.
        Node target = next.anyNext().next();
        Set<Node> parallel = new LinkedHashSet<Node>();
        for (Node other: node.next) {
          if (other != next && other.prev.size() == 1
             && other.next.size() == 1 && other.anyNext() == target) {
            parallel.add(other);
          }
        }

        // If we have found any "parallel" nodes, merge them
        if (parallel.size() > 0) {
          StringBuilder sb = new StringBuilder("allNodes(");
          sb.append(next.value);
          for (Node other: parallel) {
            sb.append(", ").append(other.value);
            next.merge(other);
          }
          sb.append(")");
          next.value = sb.toString();
          return true;
        }
      }
    }
    return false;
  }

  public static void main(String[] args) {
    Node a = new Node("A");
    Node b = new Node("B");
    Node c = new Node("C");
    Node d = new Node("D");
    Node e = new Node("E");
    Node f = new Node("F");

    f.addPrev(d);
    f.addPrev(e);

    e.addPrev(a);

    d.addPrev(b);
    d.addPrev(c);

    b.addPrev(a);
    c.addPrev(a);

    boolean anyChange;
    do {
      anyChange = false;
      for (Node node: allNodes) {
        if (processNode(node)) {
          anyChange = true;
          // We need to leave the inner loop here because changes
          // invalidate the for iteration. 
          break;
        }
      }
      // We are done if we can't find any node to merge.
    } while (anyChange);

    System.out.println(allNodes.toString());
  }
}

输出:A then all(E, all(B, C) then D) then F

此图不是在 compile-time 中构建的,但我不清楚这是否是必需的。该图保存在实现为 adjacency_list<vecS, vecS, bidirectionalS> 的增强图中。一次调度将启动任务。我们只需要每个节点的 in-edges ,这样我们就知道我们在等待什么。 pre-calculated 在下面的调度程序中实例化。

我认为不需要完整的拓扑排序。

例如,如果依赖图是:

使用scheduler_driver.cpp

对于

中的联接

只需重新定义 Graph 即可定义有向边。

所以,回答你的 2 个问题:

。是的,对于 DAG。每个节点只需要唯一的直接依赖关系,可以是 pre-computed 如下。然后可以通过一次调度启动依赖链,多米诺骨牌链倒下。

。是的,请参见下面的算法(使用 C++11 线程,而不是 boost::thread)。对于分叉,通信需要 shared_future,而基于 future 的通信支持加入。

scheduler_driver.hpp:

#ifndef __SCHEDULER_DRIVER_HPP__
#define __SCHEDULER_DRIVER_HPP__

#include <iostream>
#include <ostream>
#include <iterator>
#include <vector>
#include <chrono>

#include "scheduler.h"

#endif

scheduler_driver.cpp:

#include "scheduler_driver.hpp"

enum task_nodes
  {
    task_0,
    task_1,
    task_2,
    task_3,
    task_4,
    task_5,
    task_6,
    task_7,
    task_8,
    task_9,
    N
  };

int basic_task(int a, int d)
{
  std::chrono::milliseconds sleepDuration(d);
  std::this_thread::sleep_for(sleepDuration);
  std::cout << "Result: " << a << "\n";
  return a;
}

using namespace SCHEDULER;

int main(int argc, char **argv)
{

  using F = std::function<R()>;

  Graph deps(N);
  boost::add_edge(task_0, task_1, deps);
  boost::add_edge(task_0, task_2, deps);
  boost::add_edge(task_0, task_3, deps);
  boost::add_edge(task_1, task_4, deps);
  boost::add_edge(task_1, task_5, deps);
  boost::add_edge(task_1, task_6, deps);
  boost::add_edge(task_2, task_7, deps);
  boost::add_edge(task_2, task_8, deps);
  boost::add_edge(task_2, task_9, deps);

  std::vector<F> tasks = 
    {
      std::bind(basic_task, 0, 1000),
      std::bind(basic_task, 1, 1000),
      std::bind(basic_task, 2, 1000),
      std::bind(basic_task, 3, 1000),
      std::bind(basic_task, 4, 1000),
      std::bind(basic_task, 5, 1000),
      std::bind(basic_task, 6, 1000),
      std::bind(basic_task, 7, 1000),
      std::bind(basic_task, 8, 1000),
      std::bind(basic_task, 9, 1000)
    };

  auto s = std::make_unique<scheduler<int>>(std::move(deps), std::move(tasks));
  s->doit();

  return 0;
}

scheduler.h:

#ifndef __SCHEDULER2_H__
#define __SCHEDULER2_H__

#include <iostream>
#include <vector>
#include <iterator>
#include <functional>
#include <algorithm>
#include <mutex>
#include <thread>
#include <future>
#include <boost/graph/graph_traits.hpp>
#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/depth_first_search.hpp>
#include <boost/graph/visitors.hpp>

using namespace boost;

namespace SCHEDULER
{

  using Graph = adjacency_list<vecS, vecS, bidirectionalS>;
  using Edge = graph_traits<Graph>::edge_descriptor;
  using Vertex = graph_traits<Graph>::vertex_descriptor;
  using VectexCont = std::vector<Vertex>;
  using outIt = graph_traits<Graph>::out_edge_iterator;
  using inIt = graph_traits<Graph>::in_edge_iterator;

  template<typename R>
    class scheduler
    {
    public:
      using ret_type = R;
      using fun_type = std::function<R()>;
      using prom_type = std::promise<ret_type>;
      using fut_type = std::shared_future<ret_type>;

      scheduler() = default;
      scheduler(const Graph &deps_, const std::vector<fun_type> &tasks_) :
        g(deps_),
        tasks(tasks_) { init_();}
        scheduler(Graph&& deps_, std::vector<fun_type>&& tasks_) :
          g(std::move(deps_)),
          tasks(std::move(tasks_)) { init_(); }
        scheduler(const scheduler&) = delete;
        scheduler& operator=(const scheduler&) = delete;

        void doit();

    private:
        void init_();
        std::list<Vertex> get_sources(const Vertex& v);
        auto task_thread(fun_type&& f, int i);

        Graph g;
        std::vector<fun_type> tasks;
        std::vector<prom_type> prom;
        std::vector<fut_type> fut;
        std::vector<std::thread> th;
        std::vector<std::list<Vertex>> sources;

    };

  template<typename R>
    void
    scheduler<R>::init_()
    {
      int num_tasks = tasks.size();

      prom.resize(num_tasks);
      fut.resize(num_tasks);

      // Get the futures
      for(size_t i=0;
          i<num_tasks;
          ++i)
        {
          fut[i] = prom[i].get_future();
        }

      // Predetermine in_edges for faster traversal
      sources.resize(num_tasks);
      for(size_t i=0;
          i<num_tasks;
          ++i)
        {
          sources[i] = get_sources(i);
        }
    }

  template<typename R>
    std::list<Vertex>
    scheduler<R>::get_sources(const Vertex& v)
    {
      std::list<Vertex> r;
      Vertex v1;
      inIt j, j_end;
      boost::tie(j,j_end) = in_edges(v, g);
      for(;j != j_end;++j)
        {
          v1 = source(*j, g);
          r.push_back(v1);
        }
      return r;
    }

  template<typename R>
    auto
    scheduler<R>::task_thread(fun_type&& f, int i)
    {
      auto j_beg = sources[i].begin(), 
        j_end = sources[i].end();
      for(;
          j_beg != j_end;
          ++j_beg)
        {
          R val = fut[*j_beg].get();
        }

      return std::thread([this](fun_type f, int i)
                         {
                           prom[i].set_value(f());
                         },f,i);
    }

  template<typename R>
    void
    scheduler<R>::doit()
    {
      size_t num_tasks = tasks.size();
      th.resize(num_tasks);

      for(int i=0;
          i<num_tasks;
          ++i)
        {
          th[i] = task_thread(std::move(tasks[i]), i);
        }
      for_each(th.begin(), th.end(), mem_fn(&std::thread::join));
    }

} // namespace SCHEDULER

#endif

我不确定你的设置是什么以及为什么你需要构建 DAG,但我认为简单的贪心算法可能就足够了。

when (some task have finished) {
     mark output resources done;
     find all tasks that can be run;
     post them to thread pool;
}

考虑使用英特尔的 TBB Flow Graph 库。