如何在 Java 中实现多线程 MergeSort

How to implement a multi-threaded MergeSort in Java

我找到的大多数合并排序示例都只在单个线程中运行。这首先就抵消了使用合并排序算法的一些优势。有人可以展示使用多线程在 Java 中编写合并排序算法的正确方法吗?

该解决方案应在适用的情况下使用最新版本 Java 的功能。Stack Overflow 上已有的许多解决方案都直接使用普通线程(Thread)。我正在寻找一个演示 ForkJoin 与 RecursiveTask 配合使用的解决方案,这似乎正是 RecursiveTask 类的主要用例。

重点应放在展示具有卓越性能特征的算法上,包括尽可能优的时间复杂度和空间复杂度。

注意:提出的重复问题都不适用,因为它们都没有提供使用递归任务的解决方案,而这正是这个问题所要求的。

合并排序最方便的多线程范例是 fork-join 范例。这是从 Java 8 及更高版本提供的。以下代码演示了使用 fork-join 的合并排序。

import java.util.*;
import java.util.concurrent.*;

/**
 * Fork/join merge sort that returns a new sorted list and leaves its input untouched.
 *
 * <p>This is the straightforward variant: every task takes a private {@link ArrayList} copy of
 * the range it is responsible for, sorts both halves as sub-tasks and merges the two results
 * into a freshly allocated list. Elements that compare as equal keep their original relative
 * order.
 *
 * @param <N> element type, ordered by its natural ordering
 */
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    /** Private, mutable copy of the range this task sorts. */
    private final List<N> elements;

    /**
     * Creates a task that sorts a copy of {@code elements}.
     *
     * @param elements values to sort; copied, so the caller's list is never modified
     */
    public MergeSort(List<N> elements) {
        this.elements = new ArrayList<>(elements);
    }

    /**
     * Sorts this task's range by splitting it in half, sorting each half and merging.
     *
     * @return the sorted elements
     */
    @Override
    protected List<N> compute() {
        final int size = this.elements.size();
        if (size <= 1) {
            return this.elements;
        }

        final int pivot = size / 2;
        MergeSort<N> leftTask = new MergeSort<>(this.elements.subList(0, pivot));
        MergeSort<N> rightTask = new MergeSort<>(this.elements.subList(pivot, size));

        // Hand only the left half to the pool and sort the right half on this thread, so the
        // most recently forked task is always the first one joined.
        leftTask.fork();
        List<N> right = rightTask.compute();
        List<N> left = leftTask.join();

        return merge(left, right);
    }

    /**
     * Merges two sorted lists into a new sorted list, consuming elements from the inputs.
     *
     * @param left sorted run whose elements come first in the original order
     * @param right sorted run whose elements come second in the original order
     * @return a new list holding every element of both runs in ascending order
     */
    private List<N> merge(List<N> left, List<N> right) {
        List<N> sorted = new ArrayList<>(left.size() + right.size());
        while (!left.isEmpty() && !right.isEmpty()) {
            // "<=" takes the left element on ties, which is what keeps the sort stable.
            List<N> source = left.get(0).compareTo(right.get(0)) <= 0 ? left : right;
            // Removing the head of an ArrayList shifts every remaining element.
            sorted.add(source.remove(0));
        }
        // At most one of the runs still has elements; they are already in order.
        sorted.addAll(left);
        sorted.addAll(right);
        return sorted;
    }

    /** Sorts a small sample list on the common pool and prints the result. */
    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,10,1)));
        System.out.println("result: " + result);
    }
}

虽然不那么直观,但以下代码变体消除了对 ArrayList 的过度复制。初始未排序列表仅创建一次,对 subList 的调用本身不需要执行任何复制;而在此之前,算法每次分叉时我们都会复制数组列表。此外,现在合并列表时,我们不再每次都创建新列表并把值复制进去,而是重用左侧列表并将值插入其中。通过避免额外的复制步骤,我们提高了性能。我们在这里使用 LinkedList,是因为与 ArrayList 相比,它的插入成本相当低。我们还消除了对 remove 的调用,这在 ArrayList 上同样代价高昂。

import java.util.*;
import java.util.concurrent.*;

/**
 * Fork/join merge sort that avoids copying the input on every split.
 *
 * <p>Tasks work on {@code subList} views of the caller's list and only read from them. Each
 * single-element range is turned into its own {@link LinkedList}, and merging inserts the right
 * run into the left run instead of allocating a third list. Elements that compare as equal
 * keep their original relative order.
 *
 * @param <N> element type, ordered by its natural ordering
 */
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    /** The range this task sorts; a view that is read but never modified. */
    private final List<N> elements;

    /**
     * Creates a task that sorts {@code elements}.
     *
     * @param elements values to sort; not copied and not modified
     */
    public MergeSort(List<N> elements) {
        this.elements = elements;
    }

    /**
     * Sorts this task's range by splitting it in half, sorting each half and merging.
     *
     * @return a new list holding the sorted elements
     */
    @Override
    protected List<N> compute() {
        final int size = this.elements.size();
        if (size <= 1) {
            // Every result list starts life here, so merge() is free to mutate what it receives.
            return new LinkedList<>(this.elements);
        }

        final int pivot = size / 2;
        MergeSort<N> leftTask = new MergeSort<>(this.elements.subList(0, pivot));
        MergeSort<N> rightTask = new MergeSort<>(this.elements.subList(pivot, size));

        // Hand only the left half to the pool and sort the right half on this thread, so the
        // most recently forked task is always the first one joined.
        leftTask.fork();
        List<N> right = rightTask.compute();
        List<N> left = leftTask.join();

        return merge(left, right);
    }

    /**
     * Merges the sorted {@code right} run into the sorted {@code left} run.
     *
     * @param left sorted run whose elements come first in the original order; receives the result
     * @param right sorted run whose elements come second in the original order; only read
     * @return {@code left}, now holding every element of both runs in ascending order
     */
    private List<N> merge(List<N> left, List<N> right) {
        int leftIndex = 0;
        int rightIndex = 0;
        while (leftIndex < left.size() && rightIndex < right.size()) {
            // "<=" leaves the left element in front on ties, which keeps the sort stable.
            if (left.get(leftIndex).compareTo(right.get(rightIndex)) <= 0) {
                leftIndex++;
            } else {
                // The insert pushes the current left element one slot further, so advancing
                // leftIndex keeps it pointing at that same element.
                left.add(leftIndex++, right.get(rightIndex++));
            }
        }

        // Whatever is left of the right run is larger than everything in left: append it.
        if (rightIndex < right.size()) {
            left.addAll(right.subList(rightIndex, right.size()));
        }
        return left;
    }

    /** Sorts a small sample list on the common pool and prints the result. */
    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,-7,777777,10,1)));
        System.out.println("result: " + result);
    }
}

我们还可以通过使用迭代器来进一步改进代码,而不是在执行合并时直接调用 get。这样做的原因是通过索引获取 LinkedList 的时间性能较差(线性),因此通过使用迭代器,我们消除了在每次获取时内部迭代链表所导致的减速。迭代器上对 next 的调用是常数时间,而不是调用 get 的线性时间。下面的代码被修改为使用迭代器。

import java.util.*;
import java.util.concurrent.*;

/**
 * Fork/join merge sort whose merge step walks both runs with list iterators.
 *
 * <p>Tasks work on {@code subList} views of the caller's list and only read from them. Each
 * single-element range is turned into its own {@link LinkedList}, and merging splices the right
 * run into the left run through {@link ListIterator}s rather than by index. Elements that
 * compare as equal keep their original relative order.
 *
 * @param <N> element type, ordered by its natural ordering
 */
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    /** The range this task sorts; a view that is read but never modified. */
    private final List<N> elements;

    /**
     * Creates a task that sorts {@code elements}.
     *
     * @param elements values to sort; not copied and not modified
     */
    public MergeSort(List<N> elements) {
        this.elements = elements;
    }

    /**
     * Sorts this task's range by splitting it in half, sorting each half and merging.
     *
     * @return a new list holding the sorted elements
     */
    @Override
    protected List<N> compute() {
        final int size = this.elements.size();
        if (size <= 1) {
            // Every result list starts life here, so merge() is free to mutate what it receives.
            return new LinkedList<>(this.elements);
        }

        final int pivot = size / 2;
        MergeSort<N> leftTask = new MergeSort<>(this.elements.subList(0, pivot));
        MergeSort<N> rightTask = new MergeSort<>(this.elements.subList(pivot, size));

        // Hand only the left half to the pool and sort the right half on this thread, so the
        // most recently forked task is always the first one joined.
        leftTask.fork();
        List<N> right = rightTask.compute();
        List<N> left = leftTask.join();

        return merge(left, right);
    }

    /**
     * Merges the sorted {@code right} run into the sorted {@code left} run.
     *
     * @param left sorted run whose elements come first in the original order; receives the result
     * @param right sorted run whose elements come second in the original order; only read
     * @return {@code left}, now holding every element of both runs in ascending order
     */
    private List<N> merge(List<N> left, List<N> right) {
        ListIterator<N> leftCursor = left.listIterator();
        ListIterator<N> rightCursor = right.listIterator();
        while (rightCursor.hasNext()) {
            N candidate = rightCursor.next();
            if (!leftCursor.hasNext()) {
                // Left run exhausted: everything still in the right run goes on the end.
                leftCursor.add(candidate);
            } else if (leftCursor.next().compareTo(candidate) <= 0) {
                // The left element stays in front ("<=" keeps ties stable); step the right
                // cursor back so the same candidate is compared with the next left element.
                rightCursor.previous();
            } else {
                // Step back over the larger left element and insert the candidate before it.
                leftCursor.previous();
                leftCursor.add(candidate);
            }
        }
        return left;
    }

    /** Sorts a small sample list on the common pool and prints the result. */
    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,-7,777777,10,1)));
        System.out.println("result: " + result);
    }
}

最后是最复杂的代码版本,这次迭代使用了完全就地操作。仅创建初始 ArrayList,并且不会创建其他集合。因此逻辑特别难以遵循(所以我把它留到最后)。但应该尽可能接近理想的实现。

import java.util.*;
import java.util.concurrent.*;

/**
 * Fork/join merge sort that sorts the supplied list in place.
 *
 * <p>Tasks work on {@code subList} views of the caller's list and rearrange it using only
 * {@code get} and {@code set}; no additional collection is created. The list therefore has to
 * support {@code set}. The relative order of elements that compare as equal is not guaranteed
 * to be preserved.
 *
 * @param <N> element type, ordered by its natural ordering
 */
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    /** The range this task sorts; modified in place. */
    private final List<N> elements;

    /**
     * Creates a task that sorts {@code elements} in place.
     *
     * @param elements values to sort; not copied, and reordered by the task
     */
    public MergeSort(List<N> elements) {
        this.elements = elements;
    }

    /**
     * Sorts this task's range by splitting it in half, sorting each half and merging.
     *
     * @return the same list instance that was passed to the constructor, now sorted
     */
    @Override
    protected List<N> compute() {
        final int size = this.elements.size();
        if (size <= 1) {
            return this.elements;
        }

        final int pivot = size / 2;
        List<N> left = this.elements.subList(0, pivot);
        List<N> right = this.elements.subList(pivot, size);
        MergeSort<N> leftTask = new MergeSort<>(left);
        MergeSort<N> rightTask = new MergeSort<>(right);

        // Hand only the left half to the pool and sort the right half on this thread, so the
        // most recently forked task is always the first one joined. Both tasks sort their
        // view in place, so their return values are not needed.
        leftTask.fork();
        rightTask.compute();
        leftTask.join();

        merge(left, right);
        return this.elements;
    }

    /**
     * Merges two sorted runs in place so that {@code left} followed by {@code right} is sorted.
     *
     * <p>Walks the left run once. Whenever the current left element is larger than the smallest
     * element of the right run, the two are swapped and the displaced element is moved to its
     * sorted position inside the right run. A displaced element may travel across the whole
     * right run, so one merge can cost on the order of {@code left.size() * right.size()} moves.
     *
     * @param left sorted run occupying the lower positions
     * @param right sorted run occupying the higher positions
     */
    private void merge(List<N> left, List<N> right) {
        if (left.isEmpty() || right.isEmpty()) {
            return;
        }
        // Nothing to do when the largest left element already precedes the smallest right one.
        if (left.get(left.size() - 1).compareTo(right.get(0)) <= 0) {
            return;
        }

        for (int leftIndex = 0; leftIndex < left.size(); leftIndex++) {
            // right.get(0) is always the minimum of the right run because the run is re-sorted
            // after every swap.
            if (left.get(leftIndex).compareTo(right.get(0)) > 0) {
                swap(left, leftIndex, right, 0);
                restoreOrder(right);
            }
        }
    }

    /**
     * Moves the first element of an otherwise sorted run forward to its sorted position.
     *
     * @param run list whose elements from index 1 onwards are in ascending order
     */
    private void restoreOrder(List<N> run) {
        N displaced = run.get(0);
        int slot = 0;
        while (slot + 1 < run.size() && run.get(slot + 1).compareTo(displaced) < 0) {
            run.set(slot, run.get(slot + 1));
            slot++;
        }
        run.set(slot, displaced);
    }

    /**
     * Exchanges {@code left[leftIndex]} with {@code right[rightIndex]}.
     *
     * @param left list holding the first element
     * @param leftIndex position in {@code left}
     * @param right list holding the second element
     * @param rightIndex position in {@code right}
     */
    private void swap(List<N> left, int leftIndex, List<N> right, int rightIndex) {
        // List.set returns the value it replaced, which is exactly what goes into the left slot.
        left.set(leftIndex, right.set(rightIndex, left.get(leftIndex)));
    }

    /** Sorts a small sample list on the common pool and prints the result. */
    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(new ArrayList<>(Arrays.asList(5,9,8,7,6,1,2,3,4))));
        System.out.println("result: " + result);
    }
}