并行 flatMap 总是顺序的

Parallel flatMap always sequential

假设我有这个代码:

 Collections.singletonList(10)
            .parallelStream() // .stream() - nothing changes
            .flatMap(x -> Stream.iterate(0, i -> i + 1)
                    .limit(x)
                    .parallel()
                    .peek(m -> {
                        System.out.println(Thread.currentThread().getName());
                    }))
            .collect(Collectors.toSet());

输出是相同的线程名称,因此这里的 parallel 没有任何好处 - 我的意思是只有一个线程完成所有工作。

flatMap里面有这个代码:

result.sequential().forEach(downstream);

我理解如果 "outer" 流是并行的(它们可能会阻塞),则强制 sequential 属性,"outer" 将不得不等待 "flatMap" 完成,反之亦然(因为使用了相同的公共池)但是为什么 always 强制这样做?

这是可以在更高版本中更改的内容之一吗?

有两个不同的方面。

首先,只有一个管道,它是顺序的或并行的。在内部流中选择顺序还是并行是无关紧要的。请注意,您在引用的代码片段中看到的 downstream 消费者代表整个后续流管道,因此在您的代码中,以 .collect(Collectors.toSet()); 结尾,此消费者最终会将结果元素添加到单个 Set 不是线程安全的实例。因此,与单个消费者并行处理内部流会破坏整个操作。

如果外部流被拆分,则引用的代码可能会在不同的消费者添加到不同的集合时被同时调用。这些调用中的每一个都会处理映射到不同内部流实例的外部流的不同元素。由于您的外部流仅包含单个元素,因此无法拆分。

这种方式已经实现,也是 issue, as forEach is called on the inner stream which will pass all elements to the downstream consumer. As demonstrated by 的原因,支持懒惰和子流拆分的替代实现是可能的。但这是一种根本不同的实现方式。 Stream 实现的当前设计主要由消费者组成,因此最终,源拆分器(以及从中分离出来的那些)接收一个 Consumer 代表整个流管道 tryAdvanceforEachRemaining。相反,链接答案的解决方案进行拆分器组合,生成一个新的 Spliterator 委托给源拆分器。我想,这两种方法都有优点,但我不确定,如果反过来,OpenJDK 实现会损失多少。

对于像我这样迫切需要并行化 flatMap 并且需要一些实用解决方案的人,而不仅仅是历史和理论。

我想出的最简单的解决方案是手动进行扁平化,基本上是用 map + reduce(Stream::concat) 代替它。

下面是一个演示如何执行此操作的示例:

@Test
void testParallelStream_NOT_WORKING() throws InterruptedException, ExecutionException {
    new ForkJoinPool(10).submit(() -> {
        Stream.iterate(0, i -> i + 1).limit(2)
                .parallel()

                // does not parallelize nested streams
                .flatMap(i -> generateRangeParallel(i, 100))

                .peek(i -> System.out.println(currentThread().getName() + " : generated value: i=" + i))
                .forEachOrdered(i -> System.out.println(currentThread().getName() + " : received value: i=" + i));
    }).get();
    System.out.println("done");
}

@Test
void testParallelStream_WORKING() throws InterruptedException, ExecutionException {
    new ForkJoinPool(10).submit(() -> {
        Stream.iterate(0, i -> i + 1).limit(2)
                .parallel()

                // concatenation of nested streams instead of flatMap, parallelizes ALL the items
                .map(i -> generateRangeParallel(i, 100))
                .reduce(Stream::concat).orElse(Stream.empty())

                .peek(i -> System.out.println(currentThread().getName() + " : generated value: i=" + i))
                .forEachOrdered(i -> System.out.println(currentThread().getName() + " : received value: i=" + i));
    }).get();
    System.out.println("done");
}

Stream<Integer> generateRangeParallel(int start, int num) {
    return Stream.iterate(start, i -> i + 1).limit(num).parallel();
}

// run this method with produced output to see how work was distributed
void countThreads(String strOut) {
    var res = Arrays.stream(strOut.split("\n"))
            .map(line -> line.split("\s+"))
            .collect(Collectors.groupingBy(s -> s[0], Collectors.counting()));
    System.out.println(res);
    System.out.println("threads  : " + res.keySet().size());
    System.out.println("work     : " + res.values());
}

我机器上 运行 的统计数据:

NOT_WORKING case stats:
{ForkJoinPool-1-worker-23=100, ForkJoinPool-1-worker-5=300}
threads  : 2
work     : [100, 300]

WORKING case stats:
{ForkJoinPool-1-worker-9=16, ForkJoinPool-1-worker-23=20, ForkJoinPool-1-worker-21=36, ForkJoinPool-1-worker-31=17, ForkJoinPool-1-worker-27=177, ForkJoinPool-1-worker-13=17, ForkJoinPool-1-worker-5=21, ForkJoinPool-1-worker-19=8, ForkJoinPool-1-worker-17=21, ForkJoinPool-1-worker-3=67}
threads  : 10
work     : [16, 20, 36, 17, 177, 17, 21, 8, 21, 67]