前情提要
一个朋友询问我一道某大厂的面试题,看到题目的第一眼想到了 $\mathrm O(n\times k)$ 的动态规划算法来解决,朋友反馈说记得 $n$ 和 $k$ 至少 $10^5$,后来发现只要维护最后 $k$ 个元素的最大值即可,直接塞到 set
里面就能优化到 $\mathrm O(n\log n)$ 或 $\mathrm O(n\log k)$,应该可以 cover 掉数据范围,朋友也对复杂度满意了。但我总觉得可以 $\mathrm O(n)$ 线性解决,在为朋友写代码时,突然隐隐约约觉得单调队列似乎就是这样子,去 OI-Wiki 查询果然是这样,遂写此篇题解博客记录不断思考优化算法复杂度的心路历程。
题目大意
给定长度为 $n$ 一个数组,有正有负的数。问你在其中取哪些数,可以使得这些数的和最大。但是条件是,每 $k$ 个连续的数,都至少要取一个出来。输入是$n$、$k$、数组;输出是最大的这个和。
样例
输入
5 3
-4 -100 -9 -100 -4
输出
-9
题解
令 $dp[i]$ 表示取了第 $i$ 个数的情况下,取到的数的最大值。则状态转移方程为:
$$
dp[i] = \max_{j \in [1, \min{(k, i)}]}(dp[i - j]) + arr[i]
$$
最终答案为:
$$
\max_{j\in[0, k - 1]}(dp[n - j])
$$
直接按照数学公式计算即可得到答案,时间复杂度为 $\mathrm O(n\times k)$。
我们注意到,我们需要一直维护数组中最后 $k$ 个元素的最大值,这些最大值具有连续性($i$ 每次增加,都需要删掉一个元素,然后增加一个新元素),因此可以使用增加、删除以及查询最大值复杂度均为 $\mathrm O(\log n)$ 的数据结构来维护。
在 C++
中我们可以使用 STL
中基于红黑树的 set
容器来维护,因为容器中最多有 $k$ 个元素,所以时间复杂度为 $\mathrm O(\log k)$;同理,在 Python
中我们可以使用基于小根堆的 heapq
优先队列来维护,因为队列中最多有 $n$ 个元素(元素均为非负数),所以时间复杂度为 $\mathrm O(\log n)$。
我们注意到,在优先队列中,其实某些元素永远不会出队,而这些元素共同的性质便是比队头早入队 $k$ 个或以上且比队头的元素小。那么我们能不能在当前队头元素入队时,将上述元素都踢出队列呢?优先队列显然是做不到的:因为上述元素一般集中在堆底。但如果我们尝试换成普通的队列呢?先不管如何查询最值,每次遇到一个元素,我们如果发现队尾元素比该元素更小,那么其实队尾元素在未来永远也用不到了(因为会比该元素早离开最后 $k$ 个元素的范围),不断重复这个过程直到遇到一个比该元素大的队尾,此时再将该元素插入队尾——这时候让我们回过头来看,会惊奇地发现,整个队列中的元素竟然是单调递减的!那么想要查询最大值只要从队头开始找属于最后 $k$ 个元素范围内的第一个元素即可。与此同时我们注意到,如果队头元素已经超出最后 $k$ 个元素的范围,那么该元素未来也不可能会用到了,所以同样可以将符合这一条件的队头元素踢出队列。于是乎,我们总是可以直接查询队头元素来得到当前最后 $k$ 个元素的最大值。因为这个过程中,每个元素只进队和出队一次,而查询最值也是 $\mathrm O(1)$ 的复杂度,因此总时间复杂度为 $\mathrm O(n)$。这种允许在一端插入元素、两端都删除元素的数据结构叫做双端队列 deque
,在 Python
中位于标准库 collections
中,在 C++
中位于 STL
中。而本题中的双端队列永远保持单调递减,所以全程都具有这种单调性质的队列叫做“单调队列”,因而这种对于动态规划的优化方法叫做“单调队列优化”。
代码
$\mathrm O(n\times k)$ 解法:
def solve(arr, n, k): dp = [0] for i in range(1, n + 1): dp.append(max(dp[i - min(k, i):]) + arr[i - 1]) # 也可以像下面这样写: # dp.append(dp[i - 1]) # for j in range(2, min(k, i) + 1): # dp[i] = max(dp[i], dp[i - j]) # dp[i] += arr[i - 1] return max(dp[n - min(k, n - 1):]) if __name__ == '__main__': n, k = map(int, input().split()) arr = list(map(int, input().split())) print(solve(arr, n, k))
$\mathrm O(n\times \log n)$ 解法:
import heapq def solve(arr, n, k): Q = [(0, 0)] for i in range(1, n + 1): heapq.heappush(Q, (Q[0][0] - arr[i - 1], i)) while Q[0][1] <= i - k: heapq.heappop(Q) return -Q[0][0] if __name__ == '__main__': n, k = map(int, input().split()) arr = list(map(int, input().split())) print(solve(arr, n, k))
$\mathrm O(n)$ 解法:
保留 $dp$ 数组:from collections import deque def solve(arr, n, k): dp = [0] q = deque([0]) for i in range(1, n + 1): dp.append(dp[q[0]] + arr[i - 1]) while q and dp[q[-1]] <= dp[i]: q.pop() # 将比当前值小的元素全部弹出 q.append(i) while q[0] <= i - k: q.popleft() # 判断队首元素是否在窗口内 return dp[q[0]] if __name__ == '__main__': n, k = map(int, input().split()) arr = list(map(int, input().split())) print(solve(arr, n, k))
不保留 $dp$ 数组:
from collections import deque def solve(arr, n, k): q = deque([(0, 0)]) for i in range(1, n + 1): t = q[0][0] + arr[i - 1] while q and q[-1][0] <= t: q.pop() # 将比当前值小的元素全部弹出 q.append((t, i)) while q[0][1] <= i - k: q.popleft() # 判断队首元素是否在窗口内 return q[0][0] if __name__ == '__main__': n, k = map(int, input().split()) arr = list(map(int, input().split())) print(solve(arr, n, k))