Saturday, October 15, 2016

Count of Range Sum

Given an integer array nums, return the number of range sums that lie in [lower, upper] inclusive.
Range sum S(i, j) is defined as the sum of the elements in nums between indices i and j (i ≤ j), inclusive.
Note:
A naive algorithm of O(n²) is trivial. You MUST do better than that.
Example:
Given nums = [-2, 5, -1], lower = -2, upper = 2,
Return 3.
The three ranges are: [0, 0], [2, 2], [0, 2] and their respective sums are: -2, -1, 2.
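For reference, here is a minimal sketch of the naive O(n²) approach the note mentions. It assumes a prefix-sum array with a leading zero (prefix[k] holds the sum of the first k elements), so each range sum costs O(1). The class name BruteForceRangeSum is hypothetical. It is mainly useful as a checker for faster versions.

```java
// Naive O(n^2) reference for Count of Range Sum.
// BruteForceRangeSum is a hypothetical name used for illustration.
public class BruteForceRangeSum {

    // Counts pairs (i, j), i <= j, whose range sum S(i, j) lies in [lower, upper].
    public static int countRangeSum(int[] nums, int lower, int upper) {
        int n = nums.length;
        // prefix[k] = nums[0] + ... + nums[k - 1], prefix[0] = 0.
        // long avoids overflow when many large ints are added together.
        long[] prefix = new long[n + 1];
        for (int k = 0; k < n; k++) {
            prefix[k + 1] = prefix[k] + nums[k];
        }
        int count = 0;
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                // S(i, j) = prefix[j + 1] - prefix[i]
                long s = prefix[j + 1] - prefix[i];
                if (s >= lower && s <= upper) {
                    count++;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        // Example from the problem statement: prints 3.
        System.out.println(countRangeSum(new int[]{-2, 5, -1}, -2, 2));
    }
}
```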

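Besides the brute-force approach, the easiest approach to understand that I can find is divide and conquer, splitting the array in half the way binary search does. If we divide the array evenly in two, every range falls into one of three places: entirely in the left half, entirely in the right half, or starting in the left half and ending in the right half. For the first two, we recursively call the getCount function. For the crossing ranges, we use the same pairwise check as the naive approach, with prefix sums giving each range sum in O(1). Note that this crossing step still checks (n/2)·(n/2) pairs at the top level, so the running time satisfies T(n) = 2T(n/2) + O(n²), which is O(n²), not O(log n). To actually beat O(n²), the prefix sums in each half can be sorted during the recursion (as in merge sort), so the crossing ranges can be counted with two moving pointers in linear time, giving O(n log n) overall.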


public int countRangeSum(int[] nums, int lower, int upper) {
    int len = nums.length;
    if (lower > upper || len <= 0) {
        return 0;
    }
    // sums[i] = nums[0] + ... + nums[i] (inclusive prefix sum); long avoids int overflow
    long[] sums = new long[len];
    sums[0] = nums[0];
    for (int i = 1; i < len; i++) {
        sums[i] = sums[i - 1] + nums[i];
    }
    return getCount(nums, 0, len - 1, sums, lower, upper);
}

// Counts ranges [i, j] with left <= i <= j <= right whose sum lies in [lower, upper].
private int getCount(int[] nums, int left, int right, long[] sums, int lower, int upper) {
    if (left > right) {
        return 0;
    }
    if (left == right) {
        // single-element range
        if (nums[left] <= upper && nums[left] >= lower) {
            return 1;
        } else {
            return 0;
        }
    }
    int mid = left + (right - left) / 2;
    int count = 0;
    // ranges that start in [left, mid] and end in [mid + 1, right]
    for (int i = left; i <= mid; i++) {
        for (int j = mid + 1; j <= right; j++) {
            // sums[i] = sum of 0..i, inclusive
            // sums[j] = sum of 0..j, inclusive
            // sums[j] - sums[i] = sum of i + 1..j,
            // so we need to add nums[i] back to get the sum of i..j
            long tmp = sums[j] - sums[i] + nums[i];
            if (tmp <= upper && tmp >= lower) {
                count++;
            }
        }
    }
    // plus ranges entirely in the left half and entirely in the right half
    return count + getCount(nums, left, mid, sums, lower, upper) + getCount(nums, mid + 1, right, sums, lower, upper);
}
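The divide-and-conquer code above still compares every left/right pair in the crossing step. A commonly used way to make that step linear is to sort the prefix sums during the recursion, as in merge sort, and count the crossing ranges with two pointers. Below is a minimal sketch of that idea, assuming a prefix array with a leading zero so that S(i, j) = prefix[j + 1] - prefix[i]. The class and method names (MergeSortRangeSum, sortAndCount) are hypothetical, and this is a swapped-in technique rather than the code from this post.

```java
// Sketch: O(n log n) Count of Range Sum via merge sort on prefix sums.
// MergeSortRangeSum and sortAndCount are hypothetical names for illustration.
public class MergeSortRangeSum {

    public static int countRangeSum(int[] nums, int lower, int upper) {
        if (lower > upper) {
            return 0; // empty interval: no range sum can fit
        }
        int n = nums.length;
        // prefix[k] = nums[0] + ... + nums[k - 1], with prefix[0] = 0,
        // so S(i, j) = prefix[j + 1] - prefix[i].
        long[] prefix = new long[n + 1];
        for (int k = 0; k < n; k++) {
            prefix[k + 1] = prefix[k] + nums[k];
        }
        return sortAndCount(prefix, 0, n + 1, lower, upper);
    }

    // Counts pairs lo <= i < j < hi with lower <= prefix[j] - prefix[i] <= upper,
    // and leaves prefix[lo..hi) sorted in ascending order.
    private static int sortAndCount(long[] prefix, int lo, int hi, int lower, int upper) {
        if (hi - lo <= 1) {
            return 0;
        }
        int mid = lo + (hi - lo) / 2;
        int count = sortAndCount(prefix, lo, mid, lower, upper)
                + sortAndCount(prefix, mid, hi, lower, upper);

        // Both halves are sorted now. For each prefix[i] in the left half, the
        // matching prefix[j] in the right half form a window [start, end) that
        // only moves right as prefix[i] grows.
        int start = mid;
        int end = mid;
        for (int i = lo; i < mid; i++) {
            while (start < hi && prefix[start] - prefix[i] < lower) {
                start++;
            }
            while (end < hi && prefix[end] - prefix[i] <= upper) {
                end++;
            }
            count += end - start;
        }

        // Standard merge of the two sorted halves back into prefix[lo..hi).
        long[] merged = new long[hi - lo];
        int a = lo, b = mid, k = 0;
        while (a < mid && b < hi) {
            merged[k++] = prefix[a] <= prefix[b] ? prefix[a++] : prefix[b++];
        }
        while (a < mid) {
            merged[k++] = prefix[a++];
        }
        while (b < hi) {
            merged[k++] = prefix[b++];
        }
        System.arraycopy(merged, 0, prefix, lo, merged.length);
        return count;
    }

    public static void main(String[] args) {
        // Example from the problem statement: prints 3.
        System.out.println(countRangeSum(new int[]{-2, 5, -1}, -2, 2));
    }
}
```

Because both halves are sorted before counting, the pointers start and end never move backwards within one merge step, so each level of the recursion does O(n) work and the total is O(n log n).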

