Java快速排序优化



我有三个类。每一个都创建一个包含1000个int值的数组。

  • 类a:使用快速排序
  • 类b:使用快速排序,直到每个分区的大小降到10以下,然后用插入排序(insertionSort)对这些较小的分区进行排序。
  • 类c(这是我遇到麻烦的):与类b相同,区别在于不单独对各个小分区排序,而是在快速排序结束后,对整个几乎有序的数组执行一次插入排序(insertionSort)。

看起来类c只是类b代码的一个微小改动(而类b基本上是在类a的基础上添加了一些代码),但我真的不知道该怎么实现……求帮助!提前感谢!

a类:

import java.util.Arrays;
import java.util.Random;
/**
 * Baseline benchmark (class a): sorts an array of random {@code int} values
 * with a plain randomized quicksort (Lomuto partition scheme).
 */
public class QuickSort {
    private static final Random random = new Random();
    /** Exclusive upper bound for generated values: elements fall in [0, 9999). */
    private static final int RANDOM_INT_RANGE = 9999;

    /**
     * Creates an array of {@code size} random values in {@code [0, RANDOM_INT_RANGE)}.
     *
     * @param size number of elements to generate
     * @return the freshly generated array
     */
    private static int[] randomArray(int size) {
        final int[] arr = new int[size];
        for (int i = 0; i < arr.length; i++) {
            arr[i] = random.nextInt(RANDOM_INT_RANGE);
        }
        return arr;
    }

    /**
     * Sorts {@code arr} in ascending order, in place.
     *
     * @param arr the array to sort; an empty array is left untouched
     */
    private static void sort(int[] arr) {
        if (arr.length > 0) {
            sortInPlace(arr, 0, arr.length - 1);
        }
    }

    /**
     * Recursively quicksorts the inclusive range {@code arr[left..right]}.
     * The pivot is chosen uniformly at random, so no fixed input order
     * (e.g. already sorted data) reliably triggers the quadratic worst case.
     */
    private static void sortInPlace(int[] arr, int left, int right) {
        if (left >= right) {
            return; // zero or one element: already sorted
        }
        final int range = right - left + 1;
        int pivot = random.nextInt(range) + left;
        int newPivot = partition(arr, left, right, pivot);
        sortInPlace(arr, left, newPivot - 1);
        sortInPlace(arr, newPivot + 1, right);
    }

    /**
     * Lomuto partition of {@code arr[left..right]} around the value at {@code pivot}.
     *
     * @return the final index of the pivot value; every element before it is
     *         strictly smaller and every element after it is greater or equal
     */
    private static int partition(int[] arr, int left, int right, int pivot) {
        int pivotVal = arr[pivot];
        // Park the pivot at the end so the scan below never moves it.
        swapArrayVals(arr, pivot, right);
        int storeIndex = left;
        for (int i = left; i <= (right - 1); i++) {
            if (arr[i] < pivotVal) {
                swapArrayVals(arr, i, storeIndex);
                storeIndex++;
            }
        }
        // Move the pivot into the slot right after the "smaller" region.
        swapArrayVals(arr, storeIndex, right);
        return storeIndex;
    }

    /** Swaps {@code arr[from]} and {@code arr[to]}. */
    private static void swapArrayVals(int[] arr, int from, int to) {
        int fromVal = arr[from];
        int toVal = arr[to];
        arr[from] = toVal;
        arr[to] = fromVal;
    }

    /**
     * Generates 1000 random values, sorts them, checks the result against
     * {@link Arrays#sort(int[])}, and reports how long {@link #sort} took.
     */
    public static void main(String[] args) {
        int[] arr = randomArray(1000);
        int[] copy = Arrays.copyOf(arr, arr.length);
        System.out.println("The starting/unsorted array: \n"
                + Arrays.toString(arr));
        // Time only the sort itself; generating and printing the arrays would
        // otherwise dominate the number being compared across implementations.
        long startTime = System.nanoTime();
        sort(arr);
        long totalTime = System.nanoTime() - startTime;
        Arrays.sort(copy);
        if (Arrays.equals(arr, copy)) {
            System.out.println("The ending/sorted array: \n"
                    + Arrays.toString(arr));
            // System.nanoTime() measures nanoseconds; convert so the label is true.
            System.out.println("Total elapsed time (milliseconds) " + "is: "
                    + totalTime / 1_000_000.0);
        } else {
            System.err.println("Sort failed: result differs from Arrays.sort");
        }
    }
}

b类:

import java.util.Arrays;
import java.util.Random;
/**
 * Optimized variant (class b): randomized quicksort that stops recursing once
 * a partition holds fewer than {@link #INSERTION_SORT_THRESHOLD} elements and
 * finishes that partition with insertion sort.
 */
public class OptQSort1 {
    private static final Random random = new Random();
    /** Exclusive upper bound for generated values: elements fall in [0, 9999). */
    private static final int RANDOM_INT_RANGE = 9999;
    /** Partitions smaller than this are insertion-sorted instead of split further. */
    private static final int INSERTION_SORT_THRESHOLD = 10;

    /**
     * Creates an array of {@code size} random values in {@code [0, RANDOM_INT_RANGE)}.
     *
     * @param size number of elements to generate
     * @return the freshly generated array
     */
    private static int[] randomArray(int size) {
        final int[] arr = new int[size];
        for (int i = 0; i < arr.length; i++) {
            arr[i] = random.nextInt(RANDOM_INT_RANGE);
        }
        return arr;
    }

    /**
     * Sorts {@code arr} in ascending order, in place.
     *
     * @param arr the array to sort; an empty array is left untouched
     */
    private static void sort(int[] arr) {
        if (arr.length > 0) {
            sortInPlace(arr, 0, arr.length - 1);
        }
    }

    /**
     * Sorts the inclusive range {@code arr[left..right]}: quicksort with a random
     * pivot for large ranges, insertion sort for ranges below the threshold.
     * Only elements inside {@code [left, right]} are ever touched.
     */
    private static void sortInPlace(int[] arr, int left, int right) {
        int size = right - left + 1;
        if (size < INSERTION_SORT_THRESHOLD) {
            // Small partition: insertion sort beats further recursion here.
            // Sort ONLY this range and stop; the partition invariant already
            // guarantees these elements belong exactly in [left, right].
            insertionSort(arr, left, right);
            return;
        }
        int pivot = random.nextInt(size) + left;
        int newPivot = partition(arr, left, right, pivot);
        sortInPlace(arr, left, newPivot - 1);
        sortInPlace(arr, newPivot + 1, right);
    }

    /**
     * Lomuto partition of {@code arr[left..right]} around the value at {@code pivot}.
     *
     * @return the final index of the pivot value; every element before it is
     *         strictly smaller and every element after it is greater or equal
     */
    private static int partition(int[] arr, int left, int right, int pivot) {
        int pivotVal = arr[pivot];
        // Park the pivot at the end so the scan below never moves it.
        swapArrayVals(arr, pivot, right);
        int storeIndex = left;
        for (int i = left; i <= (right - 1); i++) {
            if (arr[i] < pivotVal) {
                swapArrayVals(arr, i, storeIndex);
                storeIndex++;
            }
        }
        // Move the pivot into the slot right after the "smaller" region.
        swapArrayVals(arr, storeIndex, right);
        return storeIndex;
    }

    /** Swaps {@code arr[from]} and {@code arr[to]}. */
    private static void swapArrayVals(int[] arr, int from, int to) {
        int fromVal = arr[from];
        int toVal = arr[to];
        arr[from] = toVal;
        arr[to] = fromVal;
    }

    /**
     * Insertion-sorts the inclusive range {@code arr[left..right]} in place.
     * A range with {@code right <= left} is left unchanged.
     *
     * @param arr   the array containing the range
     * @param left  first index of the range
     * @param right last index of the range (inclusive)
     */
    public static void insertionSort(int[] arr, int left, int right) {
        for (int out = left + 1; out <= right; out++) {
            int temp = arr[out];
            int in = out;
            // Strict '>' stops at equal keys: no pointless shifts, and stable.
            while (in > left && arr[in - 1] > temp) {
                arr[in] = arr[in - 1];
                --in;
            }
            arr[in] = temp;
        }
    }

    /**
     * Generates 1000 random values, sorts them, checks the result against
     * {@link Arrays#sort(int[])}, and reports how long {@link #sort} took.
     */
    public static void main(String[] args) {
        int[] arr = randomArray(1000);
        int[] copy = Arrays.copyOf(arr, arr.length);
        System.out.println("The starting/unsorted array: \n"
                + Arrays.toString(arr));
        // Time only the sort itself; generating and printing the arrays would
        // otherwise dominate the number being compared across implementations.
        long startTime = System.nanoTime();
        sort(arr);
        long totalTime = System.nanoTime() - startTime;
        Arrays.sort(copy);
        if (Arrays.equals(arr, copy)) {
            System.out.println("The ending/sorted array: \n"
                    + Arrays.toString(arr));
            // System.nanoTime() measures nanoseconds; convert so the label is true.
            System.out.println("Total elapsed time (milliseconds) " + "is: "
                    + totalTime / 1_000_000.0);
        } else {
            System.err.println("Sort failed: result differs from Arrays.sort");
        }
    }
}
c类:
import java.util.Arrays;
import java.util.Random;
/**
 * Optimized variant (class c): randomized quicksort that leaves partitions
 * smaller than {@link #INSERTION_SORT_THRESHOLD} elements unsorted, followed by
 * a single insertion-sort pass over the whole, nearly sorted array.
 */
public class OptQSort2 {
    private static final Random random = new Random();
    /** Exclusive upper bound for generated values: elements fall in [0, 9999). */
    private static final int RANDOM_INT_RANGE = 9999;
    /** Partitions smaller than this are left for the final insertion-sort pass. */
    private static final int INSERTION_SORT_THRESHOLD = 10;

    /**
     * Creates an array of {@code size} random values in {@code [0, RANDOM_INT_RANGE)}.
     *
     * @param size number of elements to generate
     * @return the freshly generated array
     */
    private static int[] randomArray(int size) {
        final int[] arr = new int[size];
        for (int i = 0; i < arr.length; i++) {
            arr[i] = random.nextInt(RANDOM_INT_RANGE);
        }
        return arr;
    }

    /**
     * Sorts {@code arr} in ascending order, in place.
     *
     * @param arr the array to sort; an empty array is left untouched
     */
    private static void sort(int[] arr) {
        if (arr.length > 0) {
            sortInPlace(arr, 0, arr.length - 1);
        }
        // After the quicksort phase every element already lies inside its final
        // small partition (fewer than INSERTION_SORT_THRESHOLD slots), so one
        // insertion-sort pass over the whole array only fixes local disorder.
        insertionSort(arr, 0, arr.length - 1);
    }

    /**
     * Quicksort phase over the inclusive range {@code arr[left..right]}: splits
     * around a random pivot until a range is below the threshold, then returns
     * WITHOUT sorting it (the caller's final insertion sort handles that).
     */
    private static void sortInPlace(int[] arr, int left, int right) {
        int size = right - left + 1;
        // Also covers empty/single-element ranges (size <= 1).
        if (size < INSERTION_SORT_THRESHOLD) {
            return;
        }
        int pivot = random.nextInt(size) + left;
        int newPivot = partition(arr, left, right, pivot);
        sortInPlace(arr, left, newPivot - 1);
        sortInPlace(arr, newPivot + 1, right);
    }

    /**
     * Lomuto partition of {@code arr[left..right]} around the value at {@code pivot}.
     *
     * @return the final index of the pivot value; every element before it is
     *         strictly smaller and every element after it is greater or equal
     */
    private static int partition(int[] arr, int left, int right, int pivot) {
        int pivotVal = arr[pivot];
        // Park the pivot at the end so the scan below never moves it.
        swapArrayVals(arr, pivot, right);
        int storeIndex = left;
        for (int i = left; i <= (right - 1); i++) {
            if (arr[i] < pivotVal) {
                swapArrayVals(arr, i, storeIndex);
                storeIndex++;
            }
        }
        // Move the pivot into the slot right after the "smaller" region.
        swapArrayVals(arr, storeIndex, right);
        return storeIndex;
    }

    /** Swaps {@code arr[from]} and {@code arr[to]}. */
    private static void swapArrayVals(int[] arr, int from, int to) {
        int fromVal = arr[from];
        int toVal = arr[to];
        arr[from] = toVal;
        arr[to] = fromVal;
    }

    /**
     * Insertion-sorts the inclusive range {@code arr[left..right]} in place.
     * A range with {@code right <= left} is left unchanged.
     *
     * @param arr   the array containing the range
     * @param left  first index of the range
     * @param right last index of the range (inclusive)
     */
    public static void insertionSort(int[] arr, int left, int right) {
        for (int out = left + 1; out <= right; out++) {
            int temp = arr[out];
            int in = out;
            // Strict '>' stops at equal keys: no pointless shifts, and stable.
            while (in > left && arr[in - 1] > temp) {
                arr[in] = arr[in - 1];
                --in;
            }
            arr[in] = temp;
        }
    }

    /**
     * Generates 1000 random values, sorts them, checks the result against
     * {@link Arrays#sort(int[])}, and reports how long {@link #sort} took.
     */
    public static void main(String[] args) {
        int[] arr = randomArray(1000);
        int[] copy = Arrays.copyOf(arr, arr.length);
        System.out.println("The starting/unsorted array: \n"
                + Arrays.toString(arr));
        // Start the clock only around the sort; generating and printing the
        // arrays would otherwise dominate the measurement.
        long startTime = System.nanoTime();
        sort(arr);
        long totalTime = System.nanoTime() - startTime;
        Arrays.sort(copy);
        if (Arrays.equals(arr, copy)) {
            System.out.println("The ending/sorted array: \n"
                    + Arrays.toString(arr));
            // System.nanoTime() measures nanoseconds; convert so the label is true.
            System.out.println("Total elapsed time (milliseconds) " + "is: "
                    + totalTime / 1_000_000.0);
        } else {
            System.err.println("Sort failed: result differs from Arrays.sort");
        }
    }
}

你是说类c只是实现了插入排序,根本没有快速排序,对吗?

那么原则上,类c可以只是类b,用这一行:

    sort(arr); // the call in main() that runs the sort

改为:

    insertionSort(arr, 0, arr.length - 1); // right is an inclusive index; arr.length would read past the end

(然后你可能还想删掉大量代码——删除从未被调用的方法;修改insertionSort方法,让它默认left = 0、right = arr.length - 1(right是包含在内的下标),而不是要求调用者指定;把insertionSort方法重命名为sort,等等。)

顺便说一下,类c实际上比你已经写出来的那几个类要容易得多。你可能只是需要睡一觉,明天早上就不会觉得难了。:-)

创建类c:复制类b,然后按如下方式修改:添加一个静态字段insertionSortCalled(sortInPlace是静态方法,不能访问实例变量),这样就不会在不同的递归调用中多次调用插入排序:

 // Must be static: sortInPlace is a static method and cannot read an instance field.
 private static boolean insertionSortCalled = false;

并将sortInPlace中的这一部分修改为调用insertionSort(arr, 0, arr.length-1):

private static void sortInPlace(int[] arr, int left, int right) {
// OptQSort1:
int size = right - left + 1;
**if (size < 10 && !insertionSortCalled){**
      **insertionSortCalled=true;**
      **insertionSort(arr, 0, arr.length-1);**
  }
if (left >= right)
    return; // sorted
final int range = right - left + 1;
int pivot = random.nextInt(range) + left;
int newPivot = partition(arr, left, right, pivot);
sortInPlace(arr, left, newPivot - 1);
sortInPlace(arr, newPivot + 1, right);

}

最新更新