Google — Count Inversions (Modified Merge Sort)
Advertisement
Problem (Google Deep Algo)
Count the number of inversions in an array — pairs (i, j) where i < j but nums[i] > nums[j].
Example:
nums = [2, 4, 1, 3, 5]
Inversions (shown as value pairs nums[i], nums[j]): (2,1), (4,1), (4,3) → 3
Key Insight — Modified Merge Sort
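During the merge step both halves are already sorted. So when we pick an element from the right half, it is smaller than the current left element and therefore smaller than every element still remaining in the left half — each of those remaining elements forms an inversion with it.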
Merging [2,4] and [1,3]:
Pick 1 from right → 2 elements remain in left → count += 2
Pick 2 from left
Pick 3 from right → 1 element remains in left → count += 1
Pick 4 from left
Total additions: 3 ✓
Solutions
Python
def countInversions(nums):
    """Return the number of inversions in ``nums``.

    An inversion is a pair of indices ``(i, j)`` with ``i < j`` and
    ``nums[i] > nums[j]``. The count is obtained as a by-product of a
    merge sort. ``nums`` itself is not modified, because the sort works
    on slices (copies) of the input.
    """

    def merge(left, right):
        """Merge two sorted lists; return (merged list, cross inversions)."""
        result, count = [], 0
        i = j = 0
        while i < len(left) and j < len(right):
            if left[i] <= right[j]:
                # Equal elements are taken from the left, so ties are
                # never counted as inversions.
                result.append(left[i])
                i += 1
            else:
                result.append(right[j])
                # left is sorted, so every element from index i onward is
                # greater than right[j]: each one forms an inversion with it.
                count += len(left) - i
                j += 1
        # At most one of these two tails is non-empty.
        result.extend(left[i:])
        result.extend(right[j:])
        return result, count

    def merge_sort(arr):
        """Return (sorted copy of arr, number of inversions within arr)."""
        if len(arr) <= 1:
            return arr, 0
        mid = len(arr) // 2
        left, lc = merge_sort(arr[:mid])
        right, rc = merge_sort(arr[mid:])
        merged, mc = merge(left, right)
        return merged, lc + rc + mc

    _, total = merge_sort(nums)
    return total
JavaScript
/**
 * Counts inversions in `nums`: pairs of positions (i, j) with i < j and
 * nums[i] > nums[j].
 *
 * Uses a merge sort over copies of the input (via `slice`), so `nums`
 * itself is left untouched.
 *
 * @param {number[]} nums - values to inspect
 * @returns {number} total number of inversions
 */
function countInversions(nums) {
  let inversions = 0;

  // Returns a sorted copy of `values`, adding to `inversions` as it merges.
  function sortAndCount(values) {
    if (values.length < 2) {
      return values;
    }

    const half = values.length >> 1;
    const left = sortAndCount(values.slice(0, half));
    const right = sortAndCount(values.slice(half));

    const merged = [];
    let li = 0;
    let ri = 0;
    while (li < left.length && ri < right.length) {
      if (left[li] <= right[ri]) {
        merged.push(left[li]);
        li += 1;
      } else {
        merged.push(right[ri]);
        ri += 1;
        // Every element still waiting in `left` is greater than the one
        // just taken from `right`.
        inversions += left.length - li;
      }
    }

    // Append whatever remains; at most one of the tails is non-empty.
    const leftTail = left.slice(li);
    const rightTail = right.slice(ri);
    return merged.concat(leftTail, rightTail);
  }

  sortAndCount(nums);
  return inversions;
}
Java
/**
 * Sorts {@code nums[l..r]} (both bounds inclusive) in ascending order and
 * returns the number of inversions that range contained.
 *
 * @param nums array to sort in place
 * @param tmp  scratch buffer; it is written at indices {@code l..r}, so it
 *             must be at least {@code r + 1} elements long
 * @param l    first index of the range (inclusive)
 * @param r    last index of the range (inclusive)
 * @return number of pairs (i, j) in the range with i < j and nums[i] > nums[j]
 */
long mergeSort(int[] nums, int[] tmp, int l, int r) {
    if (l >= r) {
        return 0;
    }
    // l + (r - l) / 2 instead of (l + r) / 2: the sum l + r can overflow int
    // for very large index ranges.
    int mid = l + (r - l) / 2;
    long cnt = mergeSort(nums, tmp, l, mid) + mergeSort(nums, tmp, mid + 1, r);

    int i = l;
    int j = mid + 1;
    int k = l;
    while (i <= mid && j <= r) {
        if (nums[i] <= nums[j]) {
            tmp[k++] = nums[i++];
        } else {
            tmp[k++] = nums[j++];
            // The left half is sorted, so nums[i..mid] are all greater than
            // the element just taken from the right half.
            cnt += mid - i + 1;
        }
    }
    while (i <= mid) {
        tmp[k++] = nums[i++];
    }
    while (j <= r) {
        tmp[k++] = nums[j++];
    }
    // Copy the merged run back so the caller sees nums[l..r] sorted.
    System.arraycopy(tmp, l, nums, l, r - l + 1);
    return cnt;
}
C++
#include <algorithm>
#include <vector>
using namespace std;
/// @brief Sorts the half-open range a[l, r) ascending and counts its inversions.
///
/// An inversion is a pair of positions (i, j) with i < j and a[i] > a[j].
///
/// @param a vector to sort in place
/// @param l first index of the range (inclusive)
/// @param r one past the last index of the range (exclusive)
/// @return number of inversions that were present in a[l, r)
long long mergeCount(std::vector<int>& a, int l, int r) {
    if (r - l <= 1) return 0;

    // l + (r - l) / 2 rather than (l + r) / 2: the sum l + r can overflow int.
    const int mid = l + (r - l) / 2;
    long long cnt = mergeCount(a, l, mid) + mergeCount(a, mid, r);

    std::vector<int> tmp;
    // The merged run has exactly r - l elements; reserve once to avoid regrowth.
    tmp.reserve(static_cast<std::vector<int>::size_type>(r - l));

    int i = l;
    int j = mid;
    while (i < mid && j < r) {
        if (a[i] <= a[j]) {
            tmp.push_back(a[i++]);
        } else {
            tmp.push_back(a[j++]);
            // The left half is sorted, so a[i..mid) are all greater than the
            // element just taken from the right half.
            cnt += mid - i;
        }
    }
    while (i < mid) tmp.push_back(a[i++]);
    while (j < r) tmp.push_back(a[j++]);

    // Write the merged run back over a[l, r).
    std::copy(tmp.begin(), tmp.end(), a.begin() + l);
    return cnt;
}
C
/*
 * Sorts the half-open range arr[l, r) in ascending order and returns the
 * number of inversions it contained (pairs i < j with arr[i] > arr[j]).
 *
 * arr: array to sort in place
 * tmp: scratch buffer; it is written at indices l..r-1, so it must hold at
 *      least r elements
 * l:   first index of the range (inclusive)
 * r:   one past the last index of the range (exclusive)
 */
long long merge_count(int* arr, int* tmp, int l, int r) {
    if (r - l <= 1) return 0;

    /* l + (r - l) / 2 rather than (l + r) / 2: the sum l + r can overflow int. */
    int mid = l + (r - l) / 2;
    long long cnt = merge_count(arr, tmp, l, mid) + merge_count(arr, tmp, mid, r);

    int i = l;
    int j = mid;
    int k = l;
    while (i < mid && j < r) {
        if (arr[i] <= arr[j]) {
            tmp[k++] = arr[i++];
        } else {
            tmp[k++] = arr[j++];
            /* The left half is sorted, so arr[i..mid) are all greater than
             * the element just taken from the right half. */
            cnt += mid - i;
        }
    }
    while (i < mid) tmp[k++] = arr[i++];
    while (j < r) tmp[k++] = arr[j++];

    /* Copy the merged run back so arr[l, r) ends up sorted. */
    for (int x = l; x < r; x++) arr[x] = tmp[x];
    return cnt;
}
Complexity
| Approach | Time | Space |
|---|---|---|
| Merge sort | O(n log n) | O(n) |
| Naive brute force | O(n²) | O(1) |
Advertisement