已排序数组的快速排序堆栈溢出(适用于其他数据集)

QuickSort stack overflow for sorted arrays (works for other data sets)

所以我尽力优化我的 Quicksort 算法以尽可能高效地 运行 ,即使对于排序或接近排序的数组,使用三个值的中位数的枢轴,也使用插入排序对于小分区大小。我已经针对大型随机值数组测试了我的代码并且它可以工作,但是当我传递一个已经排序的数组时,我得到了一个堆栈溢出错误(具有讽刺意味的是它让我找到了这个网站)。我认为这是我的递归调用的问题(我知道分区至少适用于其他数据集),但我不太明白要更改什么。

这是我第一学期数据结构的一部分 class 所以任何代码审查也会有所帮助。谢谢

public void quickSort(ArrayList<String> data, int firstIndex, int numberToSort) {
    if (firstIndex < (firstIndex + numberToSort - 1))
        if (numberToSort < 16) {
            insertionSort(data, firstIndex, numberToSort);
        } else {
            int pivot = partition(data, firstIndex, numberToSort);
            int leftSegmentSize = pivot - firstIndex;
            int rightSegmentSize = numberToSort - leftSegmentSize - 1;
            quickSort(data, firstIndex, leftSegmentSize);
            quickSort(data, pivot + 1, rightSegmentSize);
        }
}



public int partition(ArrayList<String> data, int firstIndex, int numberToPartition) {
    int tooBigNdx = firstIndex + 1;
    int tooSmallNdx = firstIndex + numberToPartition - 1;

    String string1 = data.get(firstIndex);
    String string2 = data.get((firstIndex + (numberToPartition - 1)) / 2);
    String string3 = data.get(firstIndex + numberToPartition - 1);
    ArrayList<String> randomStrings = new ArrayList<String>();
    randomStrings.add(string1);
    randomStrings.add(string2);
    randomStrings.add(string3);
    Collections.sort(randomStrings);
    String pivot = randomStrings.get(1);
    if (pivot == string2) {
        Collections.swap(data, firstIndex, (firstIndex + (numberToPartition - 1)) / 2);
    }
    if (pivot == string3) {
        Collections.swap(data, firstIndex, firstIndex + numberToPartition - 1);
    }
    while (tooBigNdx < tooSmallNdx) {
        while ((tooBigNdx < tooSmallNdx) && (data.get(tooBigNdx).compareTo(pivot) <= 0)) {
            tooBigNdx++;
        }
        while ((tooSmallNdx > firstIndex) && (data.get(tooSmallNdx).compareTo(pivot) > 0)) {
            tooSmallNdx--;
        }
        if (tooBigNdx < tooSmallNdx) {// swap
            Collections.swap(data, tooSmallNdx, tooBigNdx);
        }
    }
    if (pivot.compareTo(data.get(tooSmallNdx)) >= 0) {
        Collections.swap(data, firstIndex, tooSmallNdx);
        return tooSmallNdx;
    } else {
        return firstIndex;
    }
}

在您的 partition 方法中,您有时会使用超出范围的元素:

String string1 = data.get(firstIndex);
String string2 = data.get((firstIndex + (numberToPartition - 1)) / 2);
String string3 = data.get(firstIndex + numberToPartition - 1);

(firstIndex + (numberToPartition - 1)) / 2 不是中间元素的索引。那将是 (firstIndex + (firstIndex + (numberToPartition - 1))) / 2

= firstIndex + ((numberToPartition - 1) / 2).

事实上,如果 firstIndex > n/2(其中 n 是输入中的元素数),您使用的元素索引小于 firstIndex。对于排序数组,这意味着您选择 firstIndex 处的元素作为主元。因此你得到一个递归深度

<code>Omega(n)</code>,

对于足够大的输入会导致堆栈溢出。

您无需过多更改算法即可避免堆栈溢出。诀窍是在最大的分区上进行尾调用优化,而只在最小的分区上使用递归。这通常意味着您必须将 if 更改为 while。我现在无法真正测试 java 代码,但它应该类似于:

public void quickSort(ArrayList<String> data, int firstIndex, int numberToSort) {
    while (firstIndex < (firstIndex + numberToSort - 1))
        if (numberToSort < 16) {
            insertionSort(data, firstIndex, numberToSort);
        } else {
            int pivot = partition(data, firstIndex, numberToSort);
            int leftSegmentSize = pivot - firstIndex;
            int rightSegmentSize = numberToSort - leftSegmentSize - 1;

            //only use recursion for the smallest partition
            if (leftSegmentSize < rightSegmentSize) {
                quickSort(data, firstIndex, leftSegmentSize);
                firstIndex = pivot + 1;
                numberToSort = rightSegmentSize;
            } else {
                quickSort(data, pivot + 1, rightSegmentSize);
                numberToSort = leftSegmentSize;
            }
        }
}

这确保了调用堆栈的大小最多为 O(log n),因为在每次调用中,您仅对最多 n/2 大小的数组使用递归。