算法Java单调栈[leetcode 907][Java]子数组的最小值之和
inkOrCloud题目链接
leetcode - 907
前置概念
单调栈
栈中元素始终保持某种单调性,即从栈底到栈顶使用递增/递减,比如:
将[1, 4, 3, 9, 5]压入单调递增栈
- 压入1,stack:[1]
- 压入4, stack:[4, 1]
- 若直接压入3则无法保持单调性,此时需要不断弹出栈顶直到能够保持单调性,这里应该弹出4,压入3,stack:[3, 1]
- 压入9,stack:[9, 3, 1]
- 同样,若直接压入5则无法保持单调性,需要弹出9再压入,stack:[5, 3, 1]
通过压入单调栈的过程就可以得知数组中的元素A前/后第一个大于/小于A的元素,例如上述示例中,在弹出4时就可以知道元素4往后第一个小于4元素时3,在弹出9时就可以知道第一个小于9的是5
题目解析
需要计算每个子数组最小值的和,那么就可以从获取每个元素作为最小值所在的子数组,再将所有子数组的数量乘对应最小值相加即可得到答案
辐射范围
定义
这里临时定义一个辐射范围,即元素X辐射范围内所有含X的子数组最小值都是X
确定辐射范围
显然要满足辐射范围的唯一条件就是该范围内没有比X更小的元素,这时候就可以用上面说到的单调增栈获取每个元素前/后第一个小于该元素的元素位置,进而在O(n)时间复杂度内确定辐射范围:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
| int[] dp1 = new int[len]; int[] dp2 = new int[len]; for(int i = 0; i < len; i++) { dp1[i] = len; dp2[i] = -1; }
Stack<int[]> s1 = new Stack<>(); for(int i = 0; i < len; i++) { while(!s1.empty() && s1.peek()[1] > arr[i]) { int[] cur = s1.pop(); dp1[cur[0]] = i; } s1.push(new int[] {i, arr[i]}); }
Stack<int[]> s2 = new Stack<>(); for(int i = len-1; i >= 0; i--) { while(!s2.empty() && s2.peek()[1] >= arr[i]) { int[] cur = s2.pop(); dp2[cur[0]] = i; } s2.push(new int[] {i, arr[i]}); }
|
通过上面的操作就得知了每个元素的辐射范围(开区间),例如[1,4,3,9,5]中
- 1的辐射范围是(-1, 5)
- 4为(0, 2)
- 3为(0, 5)
- 9为(2, 4)
- 5为(2, 5)
计算子数组数量
得知了每个元素的辐射范围,就可以在O(n)时间复杂度内计算出每个元素作为最小值所在子数组的数量,这里分别设辐射范围的左右边界为l/r,将l/r向该元素X靠拢,每个不同的l/r组合即代表一种最小值为X的子数组,那么子数组的数量就是(ind(x)-l)*(r-ind(x))。
例如4的左右边界分别为[0,2],元素4的下标为1,那么所有以4为最小值的子数组数量就是(1-0)*(2-1)=2
计算结果
得知了每个元素对应的子数组数量,就可以得出结果:
- 1对应的子数组数量为
(0 + 1) * (5 - 0) = 5
- 4为
(1-0)*(2-1)=1
- 3为
(2-0)*(5-2)=6
- 9为
(3-2)*(4-3)=1
- 5为
(4-2)*(5-4)=2
结果为5*1 + 1*4 + 6*3 + 1*9 + 2*5 = 46
为什么极端辐射范围时正序或倒序必须有一个需要严格递增
设数组为[1, 2, 3, 1],若正序和倒序都不是严格递增那么两个1的辐射范围都是[0, 3] (闭区间)。
那么子数组[1, 2, 3, 1]对于两个1来说都在辐射范围内,那么就会对这个子数组进行重复计算,为了避免这种情况就要对左或右边界设为严格递增,即遇到相等的值就算作辐射范围的边界,此时左边的1的辐射范围就是[0, 2],而右边的1的辐射范围依然是[0, 3]
整体代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
| public class Solution { public static void main(String[] args) { Solution s = new Solution(); System.out.print(s.sumSubarrayMins(new int[] {1, 4, 3, 9, 5})); } static long MOD = (long)(1e9+7); public int sumSubarrayMins(int[] arr) { int len = arr.length; long ret = 0; int[] dp1 = new int[len]; int[] dp2 = new int[len]; for(int i = 0; i < len; i++) { dp1[i] = len; dp2[i] = -1; } Stack<int[]> s1 = new Stack<>(); for(int i = 0; i < len; i++) { while(!s1.empty() && s1.peek()[1] > arr[i]) { int[] cur = s1.pop(); dp1[cur[0]] = i; } s1.push(new int[] {i, arr[i]}); } Stack<int[]> s2 = new Stack<>(); for(int i = len-1; i >= 0; i--) { while(!s2.empty() && s2.peek()[1] >= arr[i]) { int[] cur = s2.pop(); dp2[cur[0]] = i; } s2.push(new int[] {i, arr[i]}); } for(int i = 0; i < len; i++) { long l = dp1[i], r = dp2[i]; long sum = ((i-l)*(r-i)*arr[i])%MOD; ret = (sum + ret)%MOD; } return (int)ret; } }
|