How to implement the median of medians algorithm in Java

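I am trying to implement the median of medians algorithm in Java. The algorithm should determine the median of a set of numbers. I tried to implement the pseudocode from Wikipedia: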

https://en.wikipedia.org/wiki/Median_of_medians

I am getting a buffer overflow and don't know why. Because of the recursion, I find it hard to trace the code.

import java.util.Arrays;

public class MedianSelector {
    private static final int CHUNK = 5;
    
    public static void main(String[] args) {
        int[] test = {9,8,7,6,5,4,3,2,1,0,13,11,10};
        lowerMedian(test);
        System.out.print(Arrays.toString(test));
    }
    
    /**
     * Computes and retrieves the lower median of the given array of
     * numbers using the Median algorithm presented in the lecture.
     * 
     * @param numbers the input numbers.
     * @return the lower median.
     * @throws IllegalArgumentException if the array is {@code null} or empty.
    */
    public static int lowerMedian(int[] numbers) {
        if(numbers == null || numbers.length == 0) {
            throw new IllegalArgumentException();
        }
        
        return numbers[select(numbers, 0, numbers.length - 1, (numbers.length - 1) / 2)];
    }
    
    private static int select(int[] numbers, int left, int right, int i) {
        
        if(left == right) {
            return left;
        }
        
        int pivotIndex = pivot(numbers, left, right);
        pivotIndex = partition(numbers, left, right, pivotIndex, i);
        
        if(i == pivotIndex) {
            return i;
        }else if(i < pivotIndex) {
            return select(numbers, left, pivotIndex - 1, i); 
        }else {
            return select(numbers, left, pivotIndex + 1, i);
        }
    }
    
    private static int pivot(int numbers[], int left, int right) {
        if(right - left < CHUNK) {
            return partition5(numbers, left, right);
        }
        
        for(int i=left; i<=right; i=i+CHUNK) {
            int subRight = i + (CHUNK-1);
            
            if(subRight > right) {
                subRight = right;
            }
            
            int medChunk = partition5(numbers, i, subRight);
                    
            int tmp = numbers[medChunk];
            numbers[medChunk] = numbers[(int) (left + Math.floor((double) (i-left)/CHUNK))];
            numbers[(int) (left + Math.floor((double) (i-left)/CHUNK))] = tmp;
        }
        
        int mid = (right - left) / 10 + left +1;
        return select(numbers, left, (int) (left + Math.floor((right - left) / CHUNK)), mid);
    }
    
    private static int partition(int[] numbers, int left, int right, int idx, int k) {
        int pivotVal = numbers[idx];
        int storeIndex = left;
        int storeIndexEq = 0;
        int tmp = 0;
        
        tmp = numbers[idx];
        numbers[idx] = numbers[right];
        numbers[right] = tmp;
        
        for(int i=left; i<right; i++) {
            if(numbers[i] < pivotVal) {
                tmp = numbers[i];
                numbers[i] = numbers[storeIndex];
                numbers[storeIndex] = tmp;
                storeIndex++;
            }
        }
        
        storeIndexEq = storeIndex;
        
        for(int i=storeIndex; i<right; i++) {
            if(numbers[i] == pivotVal) {
                tmp = numbers[i];
                numbers[i] = numbers[storeIndexEq];
                numbers[storeIndexEq] = tmp;
                storeIndexEq++;
            }
        }
        
        tmp = numbers[right];
        numbers[right] = numbers[storeIndexEq];
        numbers[storeIndexEq] = tmp;
        
        if(k < storeIndex) {
            return storeIndex;
        }
        
        if(k <= storeIndexEq) {
            return k;
        }
           
        return storeIndexEq;
    }
    
    //Insertion sort
    private static int partition5(int[] numbers, int left, int right) {
        int i = left + 1;
        int j = 0;
        
        while(i<=right) {
            j= i;
            while(j>left && numbers[j-1] > numbers[j]) {
                int tmp = numbers[j-1];
                numbers[j-1] = numbers[j];
                numbers[j] = tmp;
                j=j-1;
            }
            i++;
        }
        
        return left + (right - left) / 2;
    }
}

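Just to confirm: n (in the pseudocode), or i (in my code), stands for the position of the median? So let's assume our array is numbers = {9,8,7,6,5,4,3,2,1,0}. I would then call select(numbers, 0, 9, 4), correct?

I don't understand the calculation of mid in pivot. Why is there a division by 10? Maybe there is a mistake in the pseudocode?

Thanks for your help.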

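Edit: It turns out that the switch from iteration to recursion was a red herring. The actual problem, as identified by the OP, was in the arguments to the second recursive select call.

This line: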

return select(numbers, left, pivotIndex + 1, i);

should be

return select(numbers, pivotIndex + 1, right, i);
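
To see the corrected bounds in isolation, here is a deliberately simplified sketch. It is an assumption-laden illustration, not the code from the question: it uses a plain quickselect with the last element as pivot instead of the median of medians pivot, and the class name SelectBoundsSketch and its helpers are hypothetical. The only point is which sub-range each recursive call receives.

```java
// Simplified quickselect (last element as pivot), NOT the full
// median of medians algorithm. It only illustrates the bounds that
// each recursive call must receive. All names here are illustrative.
class SelectBoundsSketch {

    // Returns the index that ends up holding the k-th smallest element,
    // where k is an absolute index into the array.
    static int select(int[] a, int left, int right, int k) {
        if (left == right) {
            return left;
        }
        int p = partition(a, left, right);
        if (k == p) {
            return k;
        } else if (k < p) {
            // Target lies left of the pivot: keep left, shrink right.
            return select(a, left, p - 1, k);
        } else {
            // Target lies right of the pivot: move left, keep right.
            return select(a, p + 1, right, k);
        }
    }

    // Lomuto partition around a[right]; returns the pivot's final index.
    static int partition(int[] a, int left, int right) {
        int pivot = a[right];
        int store = left;
        for (int i = left; i < right; i++) {
            if (a[i] < pivot) {
                int t = a[i]; a[i] = a[store]; a[store] = t;
                store++;
            }
        }
        int t = a[right]; a[right] = a[store]; a[store] = t;
        return store;
    }

    public static void main(String[] args) {
        int[] numbers = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
        int idx = select(numbers, 0, numbers.length - 1, 4);
        System.out.println(numbers[idx]); // prints 4
    }
}
```

With the original arguments (left, pivotIndex + 1), the "right" branch keeps the old left bound and passes pivotIndex + 1 as the new right bound, so the range no longer brackets the target index.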

I'll leave the original answer below, as I don't want to appear more clever than I actually was.


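I think you may have misinterpreted the pseudocode for the select method: it uses iteration rather than recursion.

Here's your current implementation: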

private static int select(int[] numbers, int left, int right, int i) {
    
    if(left == right) {
        return left;
    }
    
    int pivotIndex = pivot(numbers, left, right);
    pivotIndex = partition(numbers, left, right, pivotIndex, i);
    
    if(i == pivotIndex) {
        return i;
    }else if(i < pivotIndex) {
        return select(numbers, left, pivotIndex - 1, i); 
    }else {
        return select(numbers, left, pivotIndex + 1, i);
    }
}

and the pseudocode:

function select(list, left, right, n)
    loop
        if left = right then
            return left
        pivotIndex := pivot(list, left, right)
        pivotIndex := partition(list, left, right, pivotIndex, n)
        if n = pivotIndex then
            return n
        else if n < pivotIndex then
            right := pivotIndex - 1
        else
            left := pivotIndex + 1

This would typically be implemented using a while loop:

  private static int select(int[] numbers, int left, int right, int i) {
      while(true)
      {
          if(left == right) {
              return left;
          }
          
          int pivotIndex = pivot(numbers, left, right);
          pivotIndex = partition(numbers, left, right, pivotIndex, i);
          
          if(i == pivotIndex) {
              return i;
          }else if(i < pivotIndex) {
              right = pivotIndex - 1; 
          }else {
              left = pivotIndex + 1;
          }
      }
  }

With this change your code appears to work, though obviously you'll need to test it to confirm.

int[] test = {9,8,7,6,5,4,3,2,1,0,13,11,10};
System.out.println("Lower Median: " + lowerMedian(test));

int[] check = test.clone();
Arrays.sort(check);
System.out.println(Arrays.toString(check));

Output:

Lower Median: 6
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13]
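
A single input is weak evidence, so one way to test further is to compare the result against a sorted copy for many random arrays. The sketch below is a compact restatement of the approach in this thread (iterative select, the same pivot arithmetic, the three-way partition, insertion sort on chunks of five). The class name MedianCheckSketch, the helpers sortChunk and swap, the seed, and the value ranges are all assumptions made for illustration.

```java
import java.util.Arrays;
import java.util.Random;

// Compact restatement of the thread's approach plus a randomized check.
// Class and helper names are illustrative, not from the original post.
class MedianCheckSketch {
    static final int CHUNK = 5;

    static int lowerMedian(int[] a) {
        if (a == null || a.length == 0) {
            throw new IllegalArgumentException("array is null or empty");
        }
        return a[select(a, 0, a.length - 1, (a.length - 1) / 2)];
    }

    // Iterative select: returns the index holding the k-th smallest element.
    static int select(int[] a, int left, int right, int k) {
        while (left < right) {
            int p = partition(a, left, right, pivot(a, left, right), k);
            if (k == p) return k;
            if (k < p) right = p - 1; else left = p + 1;
        }
        return left;
    }

    // Moves each chunk's median to the front of the range, then selects
    // among those medians using the same mid formula as the thread.
    static int pivot(int[] a, int left, int right) {
        if (right - left < CHUNK) return sortChunk(a, left, right);
        for (int i = left; i <= right; i += CHUNK) {
            int m = sortChunk(a, i, Math.min(i + CHUNK - 1, right));
            swap(a, m, left + (i - left) / CHUNK);
        }
        int mid = (right - left) / 10 + left + 1;
        return select(a, left, left + (right - left) / CHUNK, mid);
    }

    // Three-way partition around a[idx]; returns k when k falls inside
    // the run of elements equal to the pivot.
    static int partition(int[] a, int left, int right, int idx, int k) {
        int pivotVal = a[idx];
        swap(a, idx, right);
        int store = left;
        for (int i = left; i < right; i++) {
            if (a[i] < pivotVal) swap(a, i, store++);
        }
        int storeEq = store;
        for (int i = store; i < right; i++) {
            if (a[i] == pivotVal) swap(a, i, storeEq++);
        }
        swap(a, right, storeEq);
        if (k < store) return store;
        if (k <= storeEq) return k;
        return storeEq;
    }

    // Insertion sort on a[left..right]; returns the middle index.
    static int sortChunk(int[] a, int left, int right) {
        for (int i = left + 1; i <= right; i++) {
            for (int j = i; j > left && a[j - 1] > a[j]; j--) swap(a, j - 1, j);
        }
        return left + (right - left) / 2;
    }

    static void swap(int[] a, int i, int j) {
        int t = a[i]; a[i] = a[j]; a[j] = t;
    }

    public static void main(String[] args) {
        Random rnd = new Random(42); // fixed seed so failures are repeatable
        for (int trial = 0; trial < 1000; trial++) {
            // Small value range on purpose, to produce many duplicates.
            int[] data = rnd.ints(1 + rnd.nextInt(50), -20, 20).toArray();
            int[] sorted = data.clone();
            Arrays.sort(sorted);
            int expected = sorted[(sorted.length - 1) / 2];
            int actual = lowerMedian(data.clone());
            if (actual != expected) {
                throw new AssertionError("mismatch for " + Arrays.toString(data));
            }
        }
        System.out.println("all trials passed");
    }
}
```

The check clones the input before calling lowerMedian because select reorders the array in place; a failing input is reported in its original order so it can be replayed.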