如何找到 MergeSort 算法的最坏情况排列

How to find the worst case permutation for the MergeSort algorithm

我遇到递归错误,在重新分配递归限制时,我在尝试 运行 以下代码时遇到内存错误。

def join(A, left, right, l, m, r):
    x = 0
    for x in range(m-l):
        A[x] = left[x]
    for j in range(r-m):
        A[x+j] = right[j]enter code here

def split(A, left, right, l, m, r):
    for i in range(0, m-l, 1):
        left[i] = A[i*2]
    for i in range(0, r-m, 1):
        right[i] = A[i*2+1]

def generateWorstCase(A, l, r):
    if l < r:
        m = int(l + (r-1) / 2)
        left = [0 for i in range(m - l + 1)]
        right = [0 for i in range(r - m)]
        split(A, left, right, l, m, r)
        generateWorstCase(left, l, m)
        generateWorstCase(right, m+1, r)
        join(A, left, right, l, m, r)

arr = [1, 2, 3, 4, 5, 6, 7, 8]
generateWorstCase(arr, 0, len(arr)-1)
print(arr)

我试着翻译了 geeksforgeeks https://www.geeksforgeeks.org/find-a-permutation-that-causes-worst-case-of-merge-sort/ 给出的例子,但我仍然对在 python 中编写代码感到困惑。我了解它的工作原理(因为它会导致 mergeSort 算法比较最高数量)。我感谢任何有助于解决此问题的提示。

修正主要错误

你计算指数的方式是错误的。

m = int(l + (r-1) / 2)

让我们用实际数字来试试这个;例如:

l = 100
r = 110
m = ?    # should be in the middle, maybe 104 or 105?

m = int(l + (r-1)/2) 
m = int(100 + 109/2)
m = int(100 + 54.5)
m = 154  # wrong

这只是括号错误。要修复它:

m = (l + r) // 2
m = (100 + 110) // 2
m = 105

请注意,使用 a // bint(a / b) 更好。运算符/是python3中的浮点除法。运算符//是整数除法。我们这里不需要浮点数,所以坚持使用整数除法。

一般调试建议

下次您运行遇到类似问题时,我建议您尝试自己测试代码。我知道三种方法:手动,或使用 print,或使用调试器。

手动

拿起笔和纸。在你的纸上,写下一个小数组 A,可能有 6 个元素。记下 l = 0, r = len(A) - 1 = 5。然后阅读你的代码并在你的脑海中执行它,就像你是一台计算机一样,在你的纸上做笔记。当你阅读 m = int(l + (r-1) / 2) 时,将结果 m = 154 写在你的纸上。当你到达递归调用generateWorstCase(left, l, m)时,画一条水平线,重新开始递归调用:A = [...], l = 0, r = ... 由于数组 A 足够小,您应该能够手动 运行 整个算法,以排序数组结束,或者在出现错误时注意到(例如 m是 154 而不是 104 或 105)。

print

在您的代码中添加对 print 的调用,以打印变量在执行过程中获取的连续值,并找出何时出现问题。首先添加一些打印件,如果这还不足以找出问题所在,请添加更多打印件。打印越来越多,直到您能弄清楚问题何时出现。

例如:

def generateWorstCase(A, l, r, depth=0):
    print('  '*depth, 'generateWorstCase', 'A=', A, '; l=', l, '; r=', r)
    if l < r:
        m = int(l + (r-1) / 2)
        print('  '*depth, '                 ', 'm=', m)
        left = [0 for i in range(m - l + 1)]
        right = [0 for i in range(r - m)]
        split(A, left, right, l, m, r)
        generateWorstCase(left, l, m, depth+1)
        generateWorstCase(right, m+1, r, depth+1)
        join(A, left, right, l, m, r)

使用调试器

存在称为“调试器”的程序可以自动完成整个过程:它们执行代码的速度非常慢,允许您在执行期间暂停,在执行期间显示每个变量的值,以及许多其他很酷的东西来帮助您更好地了解正在发生的事情并找到您的错误。

修复你的功能join

您的函数 join 不正确。它只是连接两个数组 leftright 而没有做任何困难的工作。我想指出一些关于合并排序和快速排序的重要内容。如果我们总结这两种算法,它们非常相似:

Sort(a):
    split a in two halves
    recursively sort first half
    recursively sort second half
    merge the two halves

那么归并排序和快速排序有什么区别呢?区别在于实际工作发生的地方:

  • quicksort在拆分时比较元素,使得前半部分的所有元素都小于后半部分的所有元素;那么这两部分可以简单地连接起来。
  • 在归并排序中,数组是可以随机拆分的,只要每一半中有大约一半的元素即可;合并时比较元素,因此合并两个已排序的一半会产生一个已排序的数组。

简而言之:

  • 在快速排序中,split 完成工作,join 是微不足道的;
  • 在合并排序中,split 很简单,merge 完成了工作。

现在,在您的代码中,join 函数只是连接两半。那是错误的。应该比较元素。事实上,如果我们查看您的整个代码,将永远不会对任何元素进行任何比较。因此,列表不可能被正确排序。摆弄索引对列表排序没有任何作用。在某些时候,您必须比较元素,例如 if a[i] < a[j]:if left[i] < right[j]:;否则,您的算法将如何找到哪些元素较大,哪些元素较小,以便对数组进行排序?

最终代码

Python 有很多处理列表的工具,例如切片、列表推导或在不实际引用索引的情况下遍历列表元素。使用这些,将列表分成两半变得容易得多。这特别容易,因为对于归并排序算法,哪个元素最终在哪一半并不重要,所以你有很大的自由度。

这是对您的代码进行修改的示例:

def split(a):
    m = len(a) // 2 
    left = a[:m]
    right = a[m:]
    return left, right

def merge(a, left, right):
    li = 0
    ri = 0
    i = 0
    while li < len(left) and ri < len(right):
        if left[li] < right[ri]:
            a[i] = left[li]
            li += 1
        else:
            a[i] = right[ri]
            ri += 1
        i += 1
    while li < len(left):
        a[i] = left[li]
        li += 1
        i += 1
    while ri < len(right):
        a[i] = right[ri]
        ri += 1
        i += 1

def mergesort(a):
    if len(a) > 1:
        left, right = split(a)
        mergesort(left)
        mergesort(right)
        merge(a, left, right)

测试:

a = [12, 3, 7, 8, 5, 4, 9, 1, 0]
print(a)
# [12, 3, 7, 8, 5, 4, 9, 1, 0]
mergesort(a)
print(a)
# [0, 1, 3, 4, 5, 7, 8, 9, 12]

正如我所提到的,为了合并排序的目的,你可以按照你想要的方式拆分数组,这并不重要。只有合并需要仔细进行。因此,split 函数有两个替代方案:

def split(a):
    m = len(a) // 2 
    left = a[:m]
    right = a[m:]
    return left, right

def split(a):
    even = a[::2]
    odd = a[1::2]
    return even, odd

我强烈建议您通过在代码中添加 print 来找出元素移动的方式,从而找出这两个版本 split 之间的区别。

主要问题是这里的错字:m = int(l + (r-1) / 2).

不要使用 l 作为标识符,因为它在许多字体中看起来与 1 相似,容易混淆。计算中间指数的正确公式是:

    m = l + (r-l) // 2

请注意,使用整数除法 // 而不是 / 可以避免转换 int()

join 函数中还有另一个错误:for x in range(m-l) 会忘记切片中的最后一项,因为 m 被包含而不是被排除。包含切片边界的合并排序实现中无处不在的约定令人困惑,导致像这个这样的错误。考虑使用 r 作为第一个元素 切片之后的索引。

代码中还有更多问题,即临时数组和数组切片之间的混淆A。仅使用临时数组进行推理更简单。

这是一个简化版本:

def generateWorstCase(A):
    n = len(A)
    if n > 1:
        m = n // 2
        left = [A[i*2] for i in range(n-m)]
        right = [A[i*2+1] for i in range(m)]
        A = generateWorstCase(left) + generateWorstCase(right)
    return A

arr = [1, 2, 3, 4, 5, 6, 7, 8]
print(generateWorstCase(arr))

输出:[1, 5, 3, 7, 2, 6, 4, 8]

可以利用 Python 的卓越表现力进一步简化此代码:

def generateWorstCase(A):
    return A if len(A) <= 1 else generateWorstCase(A[::2]) + generateWorstCase(A[1::2])

arr = [1, 2, 3, 4, 5, 6, 7, 8]
print(generateWorstCase(arr))