我如何 运行 与 Java 平行的东西?

How do I run something parallel in Java?

我正在尝试打印一个范围内所有可能的组合。例如,如果我的 lowerBound 是 3 而我的 max 是 5,我想要以下组合:(5,4 - 5,3 - 4,3)。我已经使用下面的 helper() 函数实现了它。

当然,如果我的 max 很大,这就是很多组合,这会花费很长时间。这就是为什么我要尝试实现 ForkJoinPool,以便任务 运行 并行。为此,我创建了一个新的 ForkJoinPool。然后我遍历 r 的所有可能值(其中 r 是组合中数字的数量,在上面的示例中 r=3)。对于 r 的每个值,我创建一个新的 HelperCalculator,它扩展了 RecursiveTask<Void>。在那里我递归调用 helper() 函数。每次调用它时,我都会创建一个新的 HelperCalculator 并在其上使用 .fork()

问题如下。它没有正确生成所有可能的组合。它实际上根本不生成任何组合。我试过在 calculator.fork() 之后添加 calculator.join(),但这种情况会一直持续下去,直到出现 OutOfMemory 错误。

显然我对 ForkJoinPool 有一些误解,但在尝试了几天之后我再也看不出是什么了。

我的主要功能:

            ForkJoinPool pool = (ForkJoinPool) Executors.newWorkStealingPool();
            for (int r = 1; r < 25; r++) {
                int lowerBound = 7;
                int[] data = new int[r];
                int max = 25;
                calculator = new HelperCalculator(data, 0, max, 0, s, n, lowerBound);
                pool.execute(calculator);
                calculator.join();
            }
            pool.shutdown();

HelperCalculator class:

    protected Void compute() {
        helper(data, end, start, index, s, lowerBound);
        return null;
    }

    //Generate all possible combinations
    public void helper(int[] data , int end, int start, int index,int s, int lowerBound) {
        //If the array is filled, print it
        if (index == data.length) {
                System.out.println(Arrays.toString(data));
        } else if (start >= end) {
            data[index] = start;
            if(data[0] >= lowerBound) {
                HelperCalculator calculator = new HelperCalculator(data,end, start-1, index+1, s, n, lowerBound);
                calculator.fork();
                calculators.add(calculator);
                HelperCalculator calculator2 = new HelperCalculator(data, end, start-1, index, s, n, lowerBound);
                calculator2.fork();
                calculators.add(calculator2);
            }
        }

如何使用 ForkJoinPool 使每个 HelperCalculator 运行 并行,以便同时有 23 个 运行ning?或者我应该使用不同的解决方案?

我试过在 calculators 列表中调用 join()isDone(),但是它没有等待它正确完成,程序就退出了。

因为有人不懂算法,这里是:

    public static void main(String[] args) {
            for(int r = 3; r > 0; r--) {
                int[] data = new int[r];
                helper(data, 0, 2, 0);
            }
    }

    public static void helper(int[] data , int end, int start, int index) {
        if (index == data.length) {
            System.out.println(Arrays.toString(data));
        } else if (start >= end) {
            data[index] = start;
                helper(data, end, start - 1, index + 1);
                helper(data, end, start - 1, index);
            }
        }
    }

这个输出是:

[2, 1, 0]
[2, 1]
[2, 0]
[1, 0]
[2]
[1]
[0]

您分叉的一些任务尝试使用相同的数组来评估不同的组合。您可以通过为每个任务创建一个不同的数组或将并行度限制在那些已经有自己的数组的任务来解决这个问题,即那些具有不同长度的任务。

但还有另一种可能;根本不要使用数组。您可以将组合存储到 int 值中,因为每个 int 值都是位的组合。这不仅节省了大量内存,而且您还可以通过递增值轻松地遍历所有可能的组合,因为遍历所有 int 数字也会遍历所有可能的位组合¹。我们唯一需要实现的是通过根据位置将位解释为数字来为特定 int 值生成正确的字符串。

对于第一次尝试,我们可以采用简单的方法并使用现有的 类:

public static void main(String[] args) {
    long t0 = System.nanoTime();
    combinations(10, 25);
    long t1 = System.nanoTime();
    System.out.println((t1 - t0)/1_000_000+" ms");
    System.out.flush();
}
static void combinations(int start, int end) {
    for(int i = 1, stop = (1 << (end - start)) - 1; i <= stop; i++) {
        System.out.println(
            BitSet.valueOf(new long[]{i}).stream()
                  .mapToObj(b -> String.valueOf(b + start))
                  .collect(Collectors.joining(", ", "[", "]"))
        );
    }
}

该方法使用独占结束,因此对于您的示例,您必须像 combinations(0, 3) 那样调用它,它会打印

[0]
[1]
[0, 1]
[2]
[0, 2]
[1, 2]
[0, 1, 2]
3 ms

当然,时间可能会有所不同

对于上面的 combinations(10, 25) 示例,它会打印所有组合,然后在我的机器上打印 3477 ms。这听起来像是一个优化的机会,但我们应该首先考虑哪些操作会产生哪些成本。

迭代这些组合在这里已经被简化为一个微不足道的操作。创建字符串的成本要高一个数量级。但这与实际打印相比仍然不算什么,实际打印包括向操作系统传输数据,并且根据系统的不同,实际渲染可能会增加我们的时间。由于这是在 PrintStream 内持有锁的情况下完成的,所有试图同时打印的线程都将被阻止,使其成为不可并行化的操作。

让我们通过创建一个新的 PrintStream、禁用换行符时的自动刷新并使用能够容纳整个输出的异常大的缓冲区来确定成本的一部分:

public static void main(String[] args) {
    System.setOut(new PrintStream(
        new BufferedOutputStream(new FileOutputStream(FileDescriptor.out),1<<20),false));
    long t0 = System.nanoTime();
    combinations(10, 25);
    long t1 = System.nanoTime();
    System.out.flush();
    long t2 = System.nanoTime();
    System.out.println((t1 - t0)/1_000_000+" ms");
    System.out.println((t2 - t0)/1_000_000+" ms");
    System.out.flush();
}
static void combinations(int start, int end) {
    for(int i = 1, stop = (1 << (end - start)) - 1; i <= stop; i++) {
        System.out.println(
            BitSet.valueOf(new long[]{i}).stream()
                  .mapToObj(b -> String.valueOf(b + start))
                  .collect(Collectors.joining(", ", "[", "]"))
        );
    }
}

在我的机器上,它打印的顺序是

93 ms
3340 ms

显示代码在不可并行化打印上花费了超过三秒,而在计算上只花费了大约 100 毫秒。为了完整起见,以下代码针对 String 代降低了一个级别:

static void combinations(int start, int end) {
    for(int i = 1, stop = (1 << (end - start)) - 1; i <= stop; i++) {
        System.out.println(bits(i, start));
    }
}
static String bits(int bits, int offset) {
    StringBuilder sb = new StringBuilder().append('[');
    for(;;) {
        int bit = Integer.lowestOneBit(bits), num = Integer.numberOfTrailingZeros(bit);
        sb.append(num + offset);
        bits -= bit;
        if(bits == 0) break;
        sb.append(", ");
    }
    return sb.append(']').toString();
}

这将我机器上的计算时间减半,同时对总时间没有明显影响,现在应该不足为奇了。


但出于教育目的,忽略潜在加速的缺乏,让我们讨论如何并行化此操作。

顺序代码确实已经将任务转化为一种形式,该形式归结为从起始值到结束值的迭代。现在,我们将此代码重写为 ForkJoinTask(或合适的子类),它表示具有开始值和结束值的迭代。然后,我们添加了将这个操作一分为二的能力,通过在中间分割范围,所以我们得到两个任务在范围的每一半上迭代。这可以重复,直到我们决定有足够的潜在并行作业并在本地执行当前迭代。在本地处理之后,我们必须等待我们拆分的任何任务完成,以确保根任务的完成意味着所有子任务的完成。

public class Combinations extends RecursiveAction {
    public static void main(String[] args) {
        System.setOut(new PrintStream(new BufferedOutputStream(
            new FileOutputStream(FileDescriptor.out),1<<20),false));
        ForkJoinPool pool = (ForkJoinPool) Executors.newWorkStealingPool();
        long t0 = System.nanoTime();
        Combinations job = Combinations.get(10, 25);
        pool.execute(job);
        job.join();
        long t1 = System.nanoTime();
        System.out.flush();
        long t2 = System.nanoTime();
        System.out.println((t1 - t0)/1_000_000+" ms");
        System.out.println((t2 - t0)/1_000_000+" ms");
        System.out.flush();
    }

    public static Combinations get(int min, int max) {
        return new Combinations(min, 1, (1 << (max - min)) - 1);
    }

    final int offset, from;
    int to;

    private Combinations(int offset, int from, int to) {
        this.offset = offset;
        this.from = from;
        this.to = to;
    }

    @Override
    protected void compute() {
        ArrayDeque<Combinations> spawned = new ArrayDeque<>();
        while(getSurplusQueuedTaskCount() < 2) {
            int middle = (from + to) >>> 1;
            if(middle == from) break;
            Combinations forked = new Combinations(offset, middle, to);
            forked.fork();
            spawned.addLast(forked);
            to = middle - 1;
        }
        performLocal();
        for(;;) {
            Combinations forked = spawned.pollLast();
            if(forked == null) break;
            if(forked.tryUnfork()) forked.performLocal(); else forked.join();
        }
    }

    private void performLocal() {
        for(int i = from, stop = to; i <= stop; i++) {
            System.out.println(bits(i, offset));
        }
    }

    static String bits(int bits, int offset) {
        StringBuilder sb = new StringBuilder().append('[');
        for(;;) {
            int bit=Integer.lowestOneBit(bits), num=Integer.numberOfTrailingZeros(bit);
            sb.append(num + offset);
            bits -= bit;
            if(bits == 0) break;
            sb.append(", ");
        }
        return sb.append(']').toString();
    }
}

getSurplusQueuedTaskCount() 为我们提供了有关工作线程饱和度的提示,换句话说,分叉更多作业是否有益。将返回的数字与通常是一个小数字的阈值进行比较,工作越多样化,因此,预期的工作量,阈值应该越高,以便在工作比其他工作更早完成时允许更多的工作窃取。在我们的案例中,工作负载预计会非常平衡。

拆分有两种方式。示例通常会创建两个或多个分叉子任务,然后加入它们。这可能会导致大量任务只是在等待其他任务。另一种方法是分叉一个子任务并改变当前任务,以代表另一个。在这里,分叉任务代表 [middle, to] 范围,而当前任务被修改为代表 [from, middle] 范围。

分叉足够多的任务后,剩余范围在当前线程本地处理。然后,该任务将等待所有分叉的子任务,并进行一项优化:它将 try to unfork 子任务,如果没有其他工作线程窃取它们,则在本地处理它们。

这很顺利,但不幸的是,正如预期的那样,它并没有加速操作,因为最昂贵的部分是打印。


¹ 使用 int 来表示所有组合将支持的范围长度减少到 31,但请记住,这样的范围长度意味着 2³¹ - 1 组合,这需要迭代很多.如果这仍然是一个限制,您可以更改代码以使用 long 代替。当时支持的 63 范围长度,换句话说 2⁶³ - 1 组合,足以让计算机忙到宇宙尽头。