三向归并排序无法正确排序

Three way merge sort problem not sorting correctly

我一直在研究这种基于我的正常合并排序代码的三向合并排序算法;但是,它没有正确排序,所以我相信我的代码中可能存在一个小错误。有什么帮助吗?我已经研究了 3 个小时的代码,试图找到问题所在,但事实证明这很困难。

public class TriMergeSort {

    void merge(int arr[], int low, int mid1, int mid2, int high) { 
        int sizeA = mid1 - low + 1; 
        int sizeB =  mid2 - mid1;
        int sizeC = high - mid2;

        int A[] = new int[sizeA]; 
        int B[] = new int[sizeB]; 
        int C[] = new int[sizeC];

        for (int i = 0; i < sizeA; i++) 
            A[i] = arr[low + i]; 
        for (int j = 0; j < sizeB; j++) 
            B[j] = arr[mid1 + j + 1]; 
        for (int x = 0; x < sizeC; x++) 
            C[x] = arr[mid2 + x + 1];

        int i = 0, j = 0, x = 0; 
        int k = low; 
        
        while (i < sizeA && j < sizeB && x < sizeC) {
            
            if (A[i] < B[j] && A[i] < C[x]) { 
                arr[k] = A[i]; 
                i++; 
            } else
            if (A[i] >= B[j] && B[j] < C[x]) { 
                arr[k] = B[j]; 
                j++; 
            } else
            if (A[i] > C[x] && B[j] >= C[x]) { 
                arr[k] = C[x]; 
                x++; 
            } 
            k++; 
        } 

        while (i < sizeA) { 
            arr[k] = A[i]; 
            i++; 
            k++; 
        } 

        while (j < sizeB) { 
            arr[k] = B[j]; 
            j++; 
            k++; 
        } 
        
        while (x < sizeC) { 
            arr[k] = C[x]; 
            x++; 
            k++; 
        }
    } 

    void sort(int arr[], int low, int high) { 
        
        if (low < high) {  
            int mid1 = low + ((high - low) / 3); 
            int mid2 = low + 2 * ((high - low) / 3) + 1;

            sort(arr, low, mid1); 
            sort(arr, mid1 + 1, mid2); 
            sort(arr, mid2 + 1, high);

            merge(arr, low, mid1, mid2, high); 
        } 
    } 

    static void print(int arr[]) { 
        int n = arr.length; 
        for (int i = 0; i < n; ++i) 
            System.out.print(arr[i] + " "); 
        System.out.println(); 
    } 

    public static void main(String args[]) { 
        int arr[] = { 15, 2, 6, 7, 55, 0, 28, 41, 12 }; 

        TriMergeSort test = new TriMergeSort(); 
        test.sort(arr, 0, arr.length - 1); 

        print(arr); 
    }
} 

问题中贴出的代码工作正常。您没有贴出出问题的三向归并代码。

请注意,不应将 high 作为要排序的切片中最后一项的索引传递,而应传递切片之外的第一个元素的索引。这允许更简单的代码,而不会造成混乱和容易出错的 +1/-1 调整。

这是修改后的版本:

public class MergeSort { 

    void merge(int arr[], int low, int mid, int high) { 
        int sizeA = mid - low; 
        int sizeB = high - mid; 

        int A[] = new int[sizeA]; 
        int B[] = new int[sizeB]; 

        for (int i = 0; i < sizeA; i++) 
            A[i] = arr[low + i]; 
        for (int j = 0; j < sizeB; j++) 
            B[j] = arr[mid + j]; 

        int i = 0, j = 0; 
        int k = low; 
        
        while (i < sizeA && j < sizeB) { 
            if (A[i] <= B[j]) { 
                arr[k++] = A[i++]; 
            } else { 
                arr[k++] = B[j++]; 
            } 
        } 

        while (i < sizeA) {
            arr[k++] = A[i++];
        } 

        while (j < sizeB) { 
            arr[k++] = B[j++];
        } 
    } 

    void sort(int arr[], int low, int high) { 
        if (high - low >= 2) {  
            int mid = low + (high - low) / 2; 
            sort(arr, low, mid); 
            sort(arr, mid, high); 
            merge(arr, low, mid, high); 
        } 
    } 

    static void print(int arr[]) { 
        int n = arr.length; 
        for (int i = 0; i < n; ++i) {
            System.out.print(arr[i] + " ");
        }
        System.out.println(); 
    } 

    public static void main(String args[]) { 
        int arr[] = { 15, 2, 6, 7, 55, 0, 28, 41, 12, 10, 59 }; 
        MergeSort test = new MergeSort(); 
        test.sort(arr, 0, arr.length); 
        print(arr); 
    } 
}

要将其转换为 3 路合并版本,sort3 必须遵循以下步骤:

  • 将范围分成 3 个切片而不是 2 个:第一个切片从 low 到 mid1 = low + (high - low)/3(不含 mid1);第二个切片从 mid1 到 mid2 = low + (high - low)*2/3(不含 mid2);第三个切片从 mid2 到 high(不含 high)。
  • 对 3 个子切片中的每一个进行递归排序
  • 调用 merge3(arr, low, mid1, mid2, high),它需要:
    • 复制 3 个子切片
    • 编写一个循环,用 3 个索引值同时遍历 3 个切片,直到其中一个切片用完
    • 编写 3 个循环,分别合并剩下的 2 个切片:(A 和 B)、(B 和 C) 或 (A 和 C)
    • 编写 3 个循环,把剩余切片 A、B 或 C 中的剩余元素复制回数组

编辑:TriMergeSort 类中的 merge 函数缺少 3 个循环。这 3 个循环用于在 3 个初始切片之一用完后,合并剩下的 2 个切片;缺少它们,就解释了为什么数组没有被正确排序。此外,当 A[i] 与 C[x] 相等且不大于 B[j] 时,三个 if 分支都不成立:i、j、x 都不会前进,只有 k 在递增,循环因此永远不会结束。在 3 路合并循环之后,你应该有:

    // After the three-way loop stops, at most one pair of slices still has
    // elements left. Each loop below merges one such pair. The "..." bodies are
    // placeholders, presumably taking the smaller head of the two slices.
    // C exhausted: merge A and B.
    while (i < sizeA && j < sizeB) {
        ...
    }
    // B exhausted: merge A and C.
    while (i < sizeA && x < sizeC) {
        ...
    }
    // A exhausted: merge B and C.
    while (j < sizeB && x < sizeC) {
        ...
    }

为了避免所有这些重复循环,您可以将对索引值的测试合并到一个循环体中:

public class TriMergeSort {

    void merge(int arr[], int low, int mid1, int mid2, int high) { 
        int sizeA = mid1 - low; 
        int sizeB = mid2 - mid1;
        int sizeC = high - mid2;

        int A[] = new int[sizeA]; 
        int B[] = new int[sizeB]; 
        int C[] = new int[sizeC];

        for (int i = 0; i < sizeA; i++) 
            A[i] = arr[low + i]; 
        for (int j = 0; j < sizeB; j++) 
            B[j] = arr[mid1 + j]; 
        for (int k = 0; k < sizeC; k++) 
            C[k] = arr[mid2 + k];

        int i = 0, j = 0, k = 0;
        
        while (low < high) {
            if (i < sizeA && (j >= sizeB || A[i] <= B[j])) {
                if (k >= sizeC || A[i] <= C[k]) {
                    arr[low++] = A[i++];
                } else {
                    arr[low++] = C[k++];
                }
            } else {
                if (j < sizeB && (k >= sizeC || B[j] <= C[k])) {
                    arr[low++] = B[j++];
                } else {
                    arr[low++] = C[k++];
                }
            }
        } 
    } 

    void sort(int arr[], int low, int high) { 
        if (high - low >= 2) {  
            int mid1 = low + (high - low) / 3; 
            int mid2 = low + (high - low) * 2 / 3;
            sort(arr, low, mid1); 
            sort(arr, mid1, mid2); 
            sort(arr, mid2, high);
            merge(arr, low, mid1, mid2, high); 
        } 
    } 

    static void print(int arr[]) { 
        int n = arr.length; 
        for (int i = 0; i < n; ++i) {
            System.out.print(arr[i] + " ");
        }
        System.out.println(); 
    } 

    public static void main(String args[]) { 
        int arr[] = { 15, 2, 6, 7, 55, 0, 28, 41, 12 }; 
        TriMergeSort test = new TriMergeSort(); 
        test.sort(arr, 0, arr.length); 
        print(arr); 
    }
}

上面的 while 循环可以进一步简化,但可读性稍差:

    // Same selection as the if/else merge loop, with each inner choice written as a ternary.
    while (low < high) {
        // A's head is no larger than B's (or B is exhausted): take A unless C's head is smaller.
        if (i < sizeA && (j >= sizeB || A[i] <= B[j])) {
            arr[low++] = (k >= sizeC || A[i] <= C[k]) ? A[i++] : C[k++];
        } else {
            // A is exhausted or B's head is smaller: take B unless B is exhausted or C's head is smaller.
            arr[low++] = (j < sizeB && (k >= sizeC || B[j] <= C[k])) ? B[j++] : C[k++];
        }
    } 

甚至更进一步:

    // Single-expression form. The outer ternary decides whether A's head is a
    // candidate (no larger than B's, or B exhausted). The first inner ternary
    // picks A vs C; the second picks B vs C.
    while (low < high) {
        arr[low++] = (i < sizeA && (j >= sizeB || A[i] <= B[j])) ?
            ((k >= sizeC || A[i] <= C[k]) ? A[i++] : C[k++]) :
            (j < sizeB && (k >= sizeC || B[j] <= C[k])) ? B[j++] : C[k++];
    }