最近在读吴军老师的《计算之魂》,这是我读过唯一一本没有代码参考的算法书,也许以后会专门出一本《代码之魂》吧。
书中1.3 怎样寻找最好的算法,列举了使用不同方法求一个数列总和最大区间的不同算法,使得时间复杂度从O(N^3)降到O(N),书中的序列如下
1.5, -12.3, 3.2, -5.5, 23.2, 3.2, -1.4, -12.2, 34.2, 5.4, -7.8, 1.1, -4.9正确答案是:23.2, 3.2, -1.4, -12.2, 34.2, 5.4
1. 三重循环(O(N^3))
最直接粗暴的方法是利用三重循环,将所有序列一一组合,得到总和最大的区间,按照书中的思路,java代码的实现如下:
double[] a = {1.5, -12.3, 3.2, -5.5, 23.2, 3.2, -1.4, -12.2, 34.2, 5.4, -7.8, 1.1, -4.9};
int n = a.length;
double max = Double.MIN_VALUE;
int start = 0;
int end = 0;
for (int i = 0; i < n; i++) {
    for (int j = i; j < n; j++) {
        double sum = 0;
        for (int k = i; k <= j; k++) {
            sum += a[k];
        }
        if (sum > max) {
            max = sum;
            start = i;
            end = j;
        }
    }
}
System.out.println("Start: " + start);
System.out.println("End: " + end);2. 基于方法1的优化
方法2可以去掉一层循环,是时间复杂度变为O(N^2)
double[] a = {1.5, -12.3, 3.2, -5.5, 23.2, 3.2, -1.4, -12.2, 34.2, 5.4, -7.8, 1.1, -4.9};
int n = a.length;
double max = Double.MIN_VALUE;
int start = 0;
int end = 0;
for (int i = 0; i < n; i++) {
    double sum = 0;
    for (int j = i; j < n; j++) {
        sum += a[j];
        if (sum > max) {
            max = sum;
            start = i;
            end = j;
        }
    }
}
System.out.println("Start: " + start);
System.out.println("End: " + end);两段代码的比较如下图:

第三重循环的sum完全可以使用第二重循环序列的和相加,而不需要再循环一次,浪费计算资源。
书中描述:求p,q之间的总和S,S=(p,q),如果想要得到S=(p,q+1)的和,只需要在原来的基础上再做一次加法即可。
方法二如果数组a为好几万个数字的组合,那么计算量是十几亿,但比方法一节约了上万倍的计算量。
3. 使用分治算法
将序列一分为二,分成从1到N/2和N/2+1到N,然后分别求它们的总和最大区间。代码如下:
double[] a = {1.5, -12.3, 3.2, -5.5, 23.2, 3.2, -1.4, -12.2, 34.2, 5.4, -7.8, 1.1, -4.9};
        int n = a.length;
        int[] start = {0};
        int[] end = {0};
        double max = maxSum(a, 0, n - 1, start, end);
        System.out.println("Start: " + start[0]);
        System.out.println("End: " + end[0]);
        System.out.println("Max sum: " + max);
        System.out.println("Max sum subarray: " + Arrays.toString(Arrays.copyOfRange(a, start[0], end[0] + 1)));
    }
    static double maxSum(double[] a, int l, int r, int[] start, int[] end) {
        if (l == r) {
            start[0] = end[0] = l;
            return a[l];
        }
        int mid = (l + r) / 2;
        int s1 = mid;
        double sum = 0;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = mid; i >= l; i--) {
            sum += a[i];
            if (sum > max) {
                max = sum;
                s1 = i;
            }
        }
        sum = 0;
        int s2 = mid + 1;
        int e2 = mid + 1;
        for (int i = mid + 1; i <= r; i++) {
            sum += a[i];
            if (sum > max) {
                max = sum;
                e2 = i;
            }
        }
        double leftMax = maxSum(a, l, mid, start, end);
        double rightMax = maxSum(a, mid + 1, r, start, end);
        if (leftMax > max) {
            max = leftMax;
        }
        if (rightMax > max) {
            max = rightMax;
            start[0] = s2;
            end[0] = e2;
        }
        start[0] = s1;
        end[0] = e2;
        return max;
    }