Why am I getting this output when parallelising my tree search?

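I have a binary tree in which every node holds 0 or 1, so each root-to-leaf path is a bit string. My code prints all the bit strings sequentially, and that works fine. However, when I try to parallelise it, I get unexpected output.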

Node class:

public class Node{
  int value;
  Node left, right;
  int depth;

  public Node(int v){
    value = v;
    left = right = null;
  }
}

Sequential version of Tree.java:

import java.util.*;
import java.util.concurrent.*;

public class Tree{
  Node root;
  int levels;
  LinkedList<LinkedList<Integer>> all;

  Tree(int v){
    root = new Node(v);
    levels = 1;
    all = new LinkedList<LinkedList<Integer>>();
  }
  Tree(){
    root = null;
    levels = 0;
  }
  public static void main(String[] args){
    Tree tree = new Tree(0);
    populate(tree, tree.root, tree.levels);
    int processors = Runtime.getRuntime().availableProcessors();
    System.out.println("Available core: "+processors);
//    ForkJoinPool pool = new ForkJoinPool(processors);

    tree.printPaths(tree.root);

//    LinkedList<Integer> path = new LinkedList<Integer>();
//    PrintTask task = new PrintTask(tree.root, path, 0, tree.all);
//    pool.invoke(task);
//    for (int i=0; i < tree.all.size(); i++){
//      System.out.println(tree.all.get(i));
//    }

  }

  public static void populate(Tree t, Node n, int levels){
    levels++;
    if(levels >6){
      n.left = null;
      n.right = null;
    }
    else{
      t.levels = levels;
      n.left = new Node(0);
      n.right = new Node(1);
      populate(t, n.left, levels);
      populate(t, n.right, levels);
    }
  }

  public void printPaths(Node node)
   {
       LinkedList<Integer> path = new LinkedList<Integer>();
       printPathsRecur(node, path, 0);
//       System.out.println("Inside ForkJoin:  "+pool.invoke(new PrintTask(node, path, 0)));
   }

  LinkedList<LinkedList<Integer>> printPathsRecur(Node node, LinkedList<Integer> path, int pathLen)
    {
        if (node == null)
            return null;

        // append this node to the path array
        path.add(node.value);
        path.set(pathLen, node.value);
        pathLen++;

        // it's a leaf, so print the path that led to here
        if (node.left == null && node.right == null){
            printArray(path, pathLen);
            LinkedList<Integer> temp = new LinkedList<Integer>();
            for (int i = 0; i < pathLen; i++){
                temp.add(path.get(i));
            }
            all.add(temp);
        }
        else
        {
            printPathsRecur(node.left, path, pathLen);
            printPathsRecur(node.right, path, pathLen);
        }
        return all;
    }

    // Utility function that prints out an array on a line.
    void printArray(LinkedList<Integer> l, int len){
        for (int i = 0; i < len; i++){
            System.out.print(l.get(i) + " ");
        }
        System.out.println("");
    }
}

This produces the expected output:

0 0 0 0 0 0
0 0 0 0 0 1
0 0 0 0 1 0
0 0 0 0 1 1
...

Then I parallelised Tree.java:

import java.util.*;
import java.util.concurrent.*;

public class Tree{
  Node root;
  int levels;
  LinkedList<LinkedList<Integer>> all;

  Tree(int v){
    root = new Node(v);
    levels = 1;
    all = new LinkedList<LinkedList<Integer>>();
  }
  Tree(){
    root = null;
    levels = 0;
  }
  public static void main(String[] args){
    Tree tree = new Tree(0);
    populate(tree, tree.root, tree.levels);
    int processors = Runtime.getRuntime().availableProcessors();
    System.out.println("Available core: "+processors);
    ForkJoinPool pool = new ForkJoinPool(processors);

//    tree.printPaths(tree.root);

    LinkedList<Integer> path = new LinkedList<Integer>();
    PrintTask task = new PrintTask(tree.root, path, 0, tree.all);
    pool.invoke(task);
    for (int i=0; i < tree.all.size(); i++){
      System.out.println(tree.all.get(i));
    }

  }

  public static void populate(Tree t, Node n, int levels){
    levels++;
    if(levels >6){
      n.left = null;
      n.right = null;
    }
    else{
      t.levels = levels;
      n.left = new Node(0);
      n.right = new Node(1);
      populate(t, n.left, levels);
      populate(t, n.right, levels);
    }
  }
}

And added a task class:

import java.util.concurrent.*;
import java.util.*;

class PrintTask extends RecursiveAction {
  LinkedList<Integer> path = new LinkedList<Integer>();
  Node node;
  int pathLen;
  LinkedList<LinkedList<Integer>> all = new LinkedList<LinkedList<Integer>>();

  PrintTask(Node node, LinkedList<Integer> path, int pathLen, LinkedList<LinkedList<Integer>> all){
    this.node = node;
    this.path = path;
    this.pathLen = pathLen;
    this.all = all;
  }

  protected void compute(){
    if (node == null){
      return;
    }
    path.add(pathLen, node.value);
    pathLen++;

    if(node.left == null && node.right == null){
      printArray(path, pathLen);
      LinkedList<Integer> temp = new LinkedList<Integer>();
      for (int i = 0; i < pathLen; i++){
          temp.add(path.get(i));
      }
      all.add(temp);

    }
    else{
      invokeAll(new PrintTask(node.left, path, pathLen, all), new PrintTask(node.right, path, pathLen, all));

    }
  }
  void printArray(LinkedList<Integer> l, int len){
      for (int i = 0; i < len; i++){
          System.out.print(l.get(i) + " ");
      }
      System.out.println("");
  }

}

I get this output:

Available core: 8
0 0 1 0 1 1 1 0 0
0 1 1 0 1 1 1 0 1
0 0 1 1 1 0 0
1 1 1 1 0 1
1 0 1 1 0 1 1 1 0 0 1 1 0 0 0 1 1 1 0 1
1 1 1 1 0
0 1
...

[0, 1, 1, 0, 1, 1]
[0, 1, 1, 0, 0, 0]
[0, 1, 1, 0, 0, 1]
[0, 1, 1, 1, 0, 0]
[0, 1, 1, 1, 0, 1]
[0, 1, 1, 1, 0, 1]
[0, 1, 1, 1, 0, 1]
[0, 1, 1, 1, 0, 1]
[0, 1, 1, 1, 1, 0]
[0, 1, 1, 1, 0, 0]
...

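So the paths printed on the fly look very different from the expected output, in which every path consists of 6 bits. In this version I also store all the paths in a list of lists and print that list at the end. It contains some bit strings that look correct, but it does not contain all of them: it only outputs bit strings starting with 011.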

The problem with the parallel implementation comes from the following line of code.

invokeAll(new PrintTask(node.left, path, pathLen, all), new PrintTask(node.right, path, pathLen, all));

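invokeAll may execute the two tasks in parallel. This leads to 2 problems.

  • There is no guarantee that the left node is processed before the right node.
  • Race conditions can occur on the path list, which every task shares and mutates. Each task gets its own copy of the pathLen int, but that index points into the shared list. The shared all list is a LinkedList, which is not thread-safe, and it is modified concurrently as well.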
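
To make the second problem concrete, here is a minimal sketch that replays one possible interleaving on a single thread, so the result is deterministic. The class SharedPathDemo and its methods are hypothetical names introduced only for this illustration. A real run can interleave in many other ways, and concurrent writes to a LinkedList can also corrupt its internal structure.

```java
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

// Hypothetical demo class (not part of the original code): replays, on one
// thread, an interleaving that two sibling PrintTasks could produce.
class SharedPathDemo {

    // Copies the first len elements, like the leaf branch of compute() does.
    static List<Integer> snapshot(LinkedList<Integer> path, int len) {
        return new ArrayList<>(path.subList(0, len));
    }

    // Returns what the LEFT child records when the right child's add
    // slips in before the left child copies its path.
    static List<Integer> leftPathAfterInterleaving() {
        LinkedList<Integer> path = new LinkedList<>(); // shared by both children
        path.add(0, 0);        // root task: path = [0]

        int pathLen = 1;       // each child receives its own copy of this int
        path.add(pathLen, 0);  // left child adds its value:  path = [0, 0]
        path.add(pathLen, 1);  // right child adds its value: path = [0, 1, 0]

        // The left child resumes and copies "its" first two elements.
        return snapshot(path, pathLen + 1);
    }

    public static void main(String[] args) {
        // The left child should have recorded [0, 0].
        System.out.println(leftPathAfterInterleaving()); // prints [0, 1]
    }
}
```

The left child ends up recording the right child's bit, which is the kind of mix-up visible in the output above.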

The simplest way to correct it is to invoke the left and right tasks one after the other, as shown below:

new PrintTask(node.left, path, pathLen, all).invoke();
new PrintTask(node.right, path, pathLen, all).invoke();

But by doing this you lose the benefit of parallel processing: it is no better than sequential execution.


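To get both correctness and parallelism, make the following changes:

  • Change the type of all from LinkedList<LinkedList<Integer>> to LinkedList[]. Set the size of the array to 2 ^ (levels - 1), which is the number of leaves, so there is one slot per root-to-leaf path.
  • Introduce an insertIndex variable so that each leaf stores its list at the correct index of the result array. Left-shift insertIndex by one at every level, and for the right subtree also add 1.
  • In every non-leaf task, create 2 new linked lists (one copy of the current path per child) so that no list is shared between tasks, which avoids the race condition.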
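
The index arithmetic in the second bullet can be checked by hand. The sketch below assumes, as in populate(), that every left child holds 0 and every right child holds 1, so the bits after the root spell out the leaf's slot in binary. InsertIndexDemo and slotFor are hypothetical names used only for this illustration.

```java
// Hypothetical helper (not in the original answer): recomputes the slot a leaf
// receives from the values on its root-to-leaf path.
class InsertIndexDemo {

    // path[0] is the root; each later entry is 0 for a left child, 1 for a right child.
    static int slotFor(int... path) {
        int insertIndex = 0; // the root task starts at index 0
        for (int i = 1; i < path.length; i++) {
            // shift for the new level, then add 1 when descending to the right
            insertIndex = (insertIndex << 1) + path[i];
        }
        return insertIndex;
    }

    public static void main(String[] args) {
        System.out.println(slotFor(0, 0, 0, 0, 0, 0)); // prints 0
        System.out.println(slotFor(0, 0, 0, 0, 0, 1)); // prints 1
        System.out.println(slotFor(0, 1, 1, 0, 1, 1)); // prints 27 (binary 11011)
        System.out.println(slotFor(0, 1, 1, 1, 1, 1)); // prints 31, the last of 32 slots
    }
}
```

Because every leaf maps to a distinct slot, no two tasks ever write to the same array element, and the slots come out in the same left-to-right order as the sequential version.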

Modified PrintTask:

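import java.util.LinkedList;
import java.util.concurrent.RecursiveAction;

class PrintTask extends RecursiveAction {
    LinkedList<Integer> path;   // owned by this task only, never shared
    Node node;
    LinkedList[] all;           // one slot per leaf; each slot is written exactly once
    int insertIndex;            // slot prefix for the leaves below this node

    PrintTask(Node node, LinkedList<Integer> path, LinkedList[] all, int insertIndex) {
        this.node = node;
        this.path = path;
        this.all = all;
        this.insertIndex = insertIndex;
    }

    protected void compute() {
        if (node == null)
            return;
        path.add(node.value);
        if (node.left == null && node.right == null)
            // leaf: store the finished path in its own slot
            all[insertIndex] = path;
        else
            // each child gets its own copy of the path, so the tasks share no mutable list
            invokeAll(new PrintTask(node.left, new LinkedList<>(path), all, insertIndex << 1),
                    new PrintTask(node.right, new LinkedList<>(path), all, (insertIndex << 1) + 1));
    }
}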

Changes to main():

...
LinkedList[] result = new LinkedList[1 << (tree.levels - 1)];
PrintTask task = new PrintTask(tree.root, path, result, 0);
pool.invoke(task);
for (LinkedList linkedList : result) 
   System.out.println(linkedList);
...
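
For reference, the pieces above can be assembled into a single file. This is a sketch under some assumptions rather than the exact program from the question: ParallelPathsDemo, collectPaths and the simplified three-argument populate are names introduced here, the array is typed as LinkedList<Integer>[] instead of the raw LinkedList[], and the tree depth is passed in as a parameter instead of being read from tree.levels.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

class Node {
    int value;
    Node left, right;
    Node(int v) { value = v; }
}

class PrintTask extends RecursiveAction {
    final Node node;
    final LinkedList<Integer> path;   // private copy for this task
    final LinkedList<Integer>[] all;  // one slot per leaf
    final int insertIndex;

    PrintTask(Node node, LinkedList<Integer> path, LinkedList<Integer>[] all, int insertIndex) {
        this.node = node;
        this.path = path;
        this.all = all;
        this.insertIndex = insertIndex;
    }

    @Override
    protected void compute() {
        if (node == null) return;
        path.add(node.value);
        if (node.left == null && node.right == null) {
            all[insertIndex] = path; // each leaf writes a distinct slot
        } else {
            invokeAll(
                new PrintTask(node.left, new LinkedList<>(path), all, insertIndex << 1),
                new PrintTask(node.right, new LinkedList<>(path), all, (insertIndex << 1) + 1));
        }
    }
}

// Hypothetical driver class, standing in for the Tree class of the question.
class ParallelPathsDemo {

    // Builds a full binary tree: left children hold 0, right children hold 1.
    static void populate(Node n, int level, int levels) {
        if (level >= levels) return;
        n.left = new Node(0);
        n.right = new Node(1);
        populate(n.left, level + 1, levels);
        populate(n.right, level + 1, levels);
    }

    // Returns all root-to-leaf paths, ordered from leftmost to rightmost leaf.
    static List<LinkedList<Integer>> collectPaths(int levels) {
        Node root = new Node(0);
        populate(root, 1, levels);

        @SuppressWarnings("unchecked")
        LinkedList<Integer>[] result = new LinkedList[1 << (levels - 1)];

        ForkJoinPool pool = new ForkJoinPool();
        try {
            pool.invoke(new PrintTask(root, new LinkedList<>(), result, 0));
        } finally {
            pool.shutdown();
        }
        return new ArrayList<>(Arrays.asList(result));
    }

    public static void main(String[] args) {
        for (LinkedList<Integer> p : collectPaths(6)) {
            System.out.println(p);
        }
    }
}
```

With 6 levels this prints 32 lines, from [0, 0, 0, 0, 0, 0] to [0, 1, 1, 1, 1, 1], in the same order as the sequential version. The price is one list copy per child, which is what keeps the tasks independent of each other.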