How to implement the median of medians algorithm in Java



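I'm trying to implement the median of medians algorithm in Java. The algorithm should determine the median of a set of numbers. I tried to implement the pseudocode from Wikipedia: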

https://en.wikipedia.org/wiki/Median_of_medians

I get a buffer overflow and don't know why. Because of the recursion, I'm having a hard time tracing the code.

import java.util.Arrays;

public class MedianSelector {
    private static final int CHUNK = 5;

    public static void main(String[] args) {
        int[] test = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 13, 11, 10};
        lowerMedian(test);
        System.out.print(Arrays.toString(test));
    }

    /**
     * Computes and retrieves the lower median of the given array of
     * numbers using the Median algorithm presented in the lecture.
     *
     * @param numbers the input numbers.
     * @return the lower median.
     * @throws IllegalArgumentException if the array is {@code null} or empty.
     */
    public static int lowerMedian(int[] numbers) {
        if (numbers == null || numbers.length == 0) {
            throw new IllegalArgumentException();
        }

        return numbers[select(numbers, 0, numbers.length - 1, (numbers.length - 1) / 2)];
    }

    private static int select(int[] numbers, int left, int right, int i) {

        if (left == right) {
            return left;
        }

        int pivotIndex = pivot(numbers, left, right);
        pivotIndex = partition(numbers, left, right, pivotIndex, i);

        if (i == pivotIndex) {
            return i;
        } else if (i < pivotIndex) {
            return select(numbers, left, pivotIndex - 1, i);
        } else {
            return select(numbers, left, pivotIndex + 1, i);
        }
    }

    private static int pivot(int[] numbers, int left, int right) {
        if (right - left < CHUNK) {
            return partition5(numbers, left, right);
        }

        for (int i = left; i <= right; i = i + CHUNK) {
            int subRight = i + (CHUNK - 1);

            if (subRight > right) {
                subRight = right;
            }

            int medChunk = partition5(numbers, i, subRight);

            int tmp = numbers[medChunk];
            numbers[medChunk] = numbers[(int) (left + Math.floor((double) (i - left) / CHUNK))];
            numbers[(int) (left + Math.floor((double) (i - left) / CHUNK))] = tmp;
        }

        int mid = (right - left) / 10 + left + 1;
        return select(numbers, left, (int) (left + Math.floor((right - left) / CHUNK)), mid);
    }

    private static int partition(int[] numbers, int left, int right, int idx, int k) {
        int pivotVal = numbers[idx];
        int storeIndex = left;
        int storeIndexEq = 0;
        int tmp = 0;

        tmp = numbers[idx];
        numbers[idx] = numbers[right];
        numbers[right] = tmp;

        for (int i = left; i < right; i++) {
            if (numbers[i] < pivotVal) {
                tmp = numbers[i];
                numbers[i] = numbers[storeIndex];
                numbers[storeIndex] = tmp;
                storeIndex++;
            }
        }

        storeIndexEq = storeIndex;

        for (int i = storeIndex; i < right; i++) {
            if (numbers[i] == pivotVal) {
                tmp = numbers[i];
                numbers[i] = numbers[storeIndexEq];
                numbers[storeIndexEq] = tmp;
                storeIndexEq++;
            }
        }

        tmp = numbers[right];
        numbers[right] = numbers[storeIndexEq];
        numbers[storeIndexEq] = tmp;

        if (k < storeIndex) {
            return storeIndex;
        }

        if (k <= storeIndexEq) {
            return k;
        }

        return storeIndexEq;
    }

    // Insertion sort
    private static int partition5(int[] numbers, int left, int right) {
        int i = left + 1;
        int j = 0;

        while (i <= right) {
            j = i;
            while (j > left && numbers[j - 1] > numbers[j]) {
                int tmp = numbers[j - 1];
                numbers[j - 1] = numbers[j];
                numbers[j] = tmp;
                j = j - 1;
            }
            i++;
        }

        return left + (right - left) / 2;
    }
}

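Can someone confirm that n (in the pseudocode), or i (in my code), represents the position of the median? Suppose our array is numbers = {9,8,7,6,5,4,2,0}. Then I would call select(numbers, 0, 9, 4), right?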
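Here is a small sketch of what i means, assuming select behaves as in the question. lowerMedian passes i = (numbers.length - 1) / 2. That is the 0-based index the lower median would occupy in a sorted copy, and select returns the position where that element ends up. Note that select(numbers, 0, 9, 4) implies a 10-element array, because right = numbers.length - 1 = 9. The eight-element array above would instead give select(numbers, 0, 7, 3). The class and method names are hypothetical. The sort-based method is only a reference to compare against; it is not median of medians.

```java
import java.util.Arrays;

// Hypothetical helper class, not part of the question's code.
class LowerMedianRank {
    // The rank lowerMedian passes to select: the 0-based sorted position
    // of the lower median (for even n, the smaller of the two middle values).
    static int lowerMedianRank(int n) {
        return (n - 1) / 2;
    }

    // Reference answer: sort a copy and read the element at that rank.
    static int lowerMedianBySorting(int[] numbers) {
        int[] sorted = numbers.clone();
        Arrays.sort(sorted);
        return sorted[lowerMedianRank(sorted.length)];
    }

    public static void main(String[] args) {
        // A 10-element array matches select(numbers, 0, 9, 4).
        int[] ten = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
        System.out.println(lowerMedianRank(ten.length) + " " + lowerMedianBySorting(ten)); // 4 4
        // The 8-element array from the question gives rank 3, not 4.
        int[] eight = {9, 8, 7, 6, 5, 4, 2, 0};
        System.out.println(lowerMedianRank(eight.length) + " " + lowerMedianBySorting(eight)); // 3 5
    }
}
```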

I don't understand how mid is computed in pivot. Why divide by 10? Is there an error in the pseudocode?
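A quick arithmetic check, offered as a sketch, suggests where the 10 comes from. After the loop in pivot, the group medians occupy positions left through left + (right - left) / 5. So (right - left) / 5 is the distance to the end of that block, and (right - left) / 10 is roughly half of it: the middle of the block of medians. The pseudocode then adds left + 1. The helper names below are hypothetical; the two expressions are copied from pivot in the question.

```java
// Hypothetical helper names; the expressions mirror pivot() in the question.
class MedianOfMediansMid {
    // Last index of the block of group medians that pivot() builds with
    // CHUNK = 5: the medians occupy left .. left + (right - left) / 5.
    static int lastMedianIndex(int left, int right) {
        return left + (right - left) / 5;
    }

    // The target rank from the pseudocode. (right - left) / 10 is about half
    // of (right - left) / 5, i.e. roughly the middle of the median block.
    static int mid(int left, int right) {
        return (right - left) / 10 + left + 1;
    }

    public static void main(String[] args) {
        // 13 elements (0..12): groups 0-4, 5-9, 10-12, so medians at 0..2.
        System.out.println(lastMedianIndex(0, 12) + " " + mid(0, 12)); // 2 2
        // 50 elements (0..49): 10 groups, so medians at 0..9.
        System.out.println(lastMedianIndex(0, 49) + " " + mid(0, 49)); // 9 5
    }
}
```

For 13 elements, the + 1 lands on the last of the three group medians rather than the middle one. That does not make the result wrong: quickselect returns the correct element with any pivot, and only how evenly each partition splits the range depends on the pivot choice. For every range that pivot recurses on (right - left >= 5), mid stays inside the block of medians.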

Thanks for your help.

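EDIT: It turns out that the switch from iteration to recursion was a red herring. The actual problem, which the OP identified, was in the arguments of the second recursive select call.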

This line:

return select(numbers, left, pivotIndex + 1, i);

should be:

return select(numbers, pivotIndex + 1, right, i);
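For reference, here is a sketch of the whole class with that one line changed. Its main method also runs a randomized comparison against Arrays.sort. Apart from the fix, the methods are the question's, lightly tidied: a hypothetical swap helper, Math.min, and integer division in place of Math.floor on non-negative values. None of these changes should alter behavior. The random check is a sanity test, not a proof.

```java
import java.util.Arrays;
import java.util.Random;
// The question's MedianSelector with the corrected second recursive call in
// select(); the rest is lightly tidied without (intended) behavior changes.
class MedianSelector {
    private static final int CHUNK = 5;

    public static int lowerMedian(int[] numbers) {
        if (numbers == null || numbers.length == 0) {
            throw new IllegalArgumentException("array is null or empty");
        }
        return numbers[select(numbers, 0, numbers.length - 1, (numbers.length - 1) / 2)];
    }

    // Returns the index holding the element of absolute rank i in numbers[left..right].
    private static int select(int[] numbers, int left, int right, int i) {
        if (left == right) {
            return left;
        }
        int pivotIndex = pivot(numbers, left, right);
        pivotIndex = partition(numbers, left, right, pivotIndex, i);
        if (i == pivotIndex) {
            return i;
        } else if (i < pivotIndex) {
            return select(numbers, left, pivotIndex - 1, i);
        } else {
            return select(numbers, pivotIndex + 1, right, i); // the fix
        }
    }

    private static int pivot(int[] numbers, int left, int right) {
        if (right - left < CHUNK) {
            return partition5(numbers, left, right);
        }
        for (int i = left; i <= right; i += CHUNK) {
            int subRight = Math.min(i + CHUNK - 1, right);
            int medChunk = partition5(numbers, i, subRight);
            // Move this group's median into the block left, left + 1, ...
            swap(numbers, medChunk, left + (i - left) / CHUNK);
        }
        int mid = (right - left) / 10 + left + 1;
        return select(numbers, left, left + (right - left) / CHUNK, mid);
    }

    // Three-way partition around numbers[idx]: < pivot, == pivot, > pivot.
    private static int partition(int[] numbers, int left, int right, int idx, int k) {
        int pivotVal = numbers[idx];
        swap(numbers, idx, right);
        int storeIndex = left;
        for (int i = left; i < right; i++) {
            if (numbers[i] < pivotVal) {
                swap(numbers, i, storeIndex);
                storeIndex++;
            }
        }
        int storeIndexEq = storeIndex;
        for (int i = storeIndex; i < right; i++) {
            if (numbers[i] == pivotVal) {
                swap(numbers, i, storeIndexEq);
                storeIndexEq++;
            }
        }
        swap(numbers, right, storeIndexEq);
        return k < storeIndex ? storeIndex : Math.min(k, storeIndexEq);
    }

    // Insertion sort of numbers[left..right]; returns the lower middle index.
    private static int partition5(int[] numbers, int left, int right) {
        for (int i = left + 1; i <= right; i++) {
            for (int j = i; j > left && numbers[j - 1] > numbers[j]; j--) {
                swap(numbers, j - 1, j);
            }
        }
        return left + (right - left) / 2;
    }

    private static void swap(int[] a, int x, int y) {
        int tmp = a[x];
        a[x] = a[y];
        a[y] = tmp;
    }

    public static void main(String[] args) {
        int[] test = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 13, 11, 10};
        System.out.println("Lower Median: " + lowerMedian(test)); // Lower Median: 6
        // Cross-check against sorting on random arrays, duplicates included.
        Random rnd = new Random(42);
        for (int trial = 0; trial < 1000; trial++) {
            int[] a = rnd.ints(1 + rnd.nextInt(60), 0, 20).toArray();
            int[] sorted = a.clone();
            Arrays.sort(sorted);
            if (lowerMedian(a) != sorted[(a.length - 1) / 2]) {
                throw new AssertionError("mismatch for " + Arrays.toString(sorted));
            }
        }
        System.out.println("random checks passed");
    }
}
```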

I'll leave my original answer below, since I don't want to look smarter than I actually am.


I think you may have misread the pseudocode for the select method: it uses iteration, not recursion.

Here is your current implementation:

private static int select(int[] numbers, int left, int right, int i) {

    if (left == right) {
        return left;
    }

    int pivotIndex = pivot(numbers, left, right);
    pivotIndex = partition(numbers, left, right, pivotIndex, i);

    if (i == pivotIndex) {
        return i;
    } else if (i < pivotIndex) {
        return select(numbers, left, pivotIndex - 1, i);
    } else {
        return select(numbers, left, pivotIndex + 1, i);
    }
}

The pseudocode:

function select(list, left, right, n)
    loop
        if left = right then
            return left
        pivotIndex := pivot(list, left, right)
        pivotIndex := partition(list, left, right, pivotIndex, n)
        if n = pivotIndex then
            return n
        else if n < pivotIndex then
            right := pivotIndex - 1
        else
            left := pivotIndex + 1

This is usually implemented with a while loop:

private static int select(int[] numbers, int left, int right, int i) {
    while (true) {
        if (left == right) {
            return left;
        }

        int pivotIndex = pivot(numbers, left, right);
        pivotIndex = partition(numbers, left, right, pivotIndex, i);

        if (i == pivotIndex) {
            return i;
        } else if (i < pivotIndex) {
            right = pivotIndex - 1;
        } else {
            left = pivotIndex + 1;
        }
    }
}

With this change your code seems to work, although you would obviously need to test it to be sure.

int[] test = {9,8,7,6,5,4,3,2,1,0,13,11,10};
System.out.println("Lower Median: " + lowerMedian(test));
int[] check = test.clone();
Arrays.sort(check);
System.out.println(Arrays.toString(check));

Output:

Lower Median: 6
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13]
