引入
在处理区间问题时,我们常常用分治思想,将区间分为若干小区间,然后对小区间进行区间统计,最后进行区间合并,但是当区间信息不方便合并时,我们可以使用莫队。
详解
例如对于小B的询问。
题意:
- 小B有一个长为 $n$ 的整数序列 $a$,值域为 $[1,k]$。
他一共有 $m$ 个询问,每个询问给定一个区间 $[l,r]$,求: $\sum\limits_{i=1}^k c_i^2$。
其中 $c_i$ 表示数字 $i$ 在 $[l,r]$ 中的出现次数。
小B请你帮助他回答询问。
分析
对于这道题,我们发现它有几个性质:
1.不方便进行区间信息合并
- 数据范围较小,但可以卡掉 $O(n^2)$
- 根据一个已知答案的区间 $[l, r]$,可以快速求出区间 $[l, r + 1]$ 的答案。
- 对于性质1,我们无法用线段树,前缀,st表等结构来维护,例如线段树的统计信息无法维护。
- 对于性质3,具体这样实现的。比如我们增加一个元素 $v$ 进来,相当于让答案从 $c_v^2+...$ 变成了$(c_v+1)^2+...$,那么相当于答案增加 $2 \times c_v+1$,减少元素也类似。
于是我们得到一个暴力算法,我们对于第一个查询暴力算出来答案和 $c$ 数组,然后增加或减少若干元素并维护信息,求出第二个查询,然后在增加减少若干个元素并维护信息,求出第三个查询 $...$。
代码:
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
using namespace std;
const int maxn = 50010;
int n, m, m_, k, d, c[maxn], a[maxn];
long long ans, res[maxn];
struct E {
int l, r, id;
};
vector<E> q;
bool cmp(E x, E y) {
if (x.l / d != y.l / d) return x.l / d < y.l / d;
if (x.id % 2) return x.r > y.r;
return x.r < y.r;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n >> m >> k;
d = sqrt(n * n / m);
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 1; i <= m; i++) {
E t;
cin >> t.l >> t.r;
t.id = i;
q.push_back(t);
}
sort(q.begin(), q.end(), cmp);
for (int i = q[0].l; i <= q[0].r; i++) c[a[i]]++;
for (int i = 1; i <= k; i++) ans += c[i] * c[i];
res[q[0].id] = ans; // 暴力算第一个
int l = q[0].l, r = q[0].r;
for (int i = 1; i < m; i++) {
while (l > q[i].l) ans += (c[a[--l]]++) * 2, ans++;
while (r < q[i].r) ans += (c[a[++r]]++) * 2, ans++;
while (l < q[i].l) ans -= (c[a[l++]]--) * 2, ans++;
while (r > q[i].r) ans -= (c[a[r--]]--) * 2, ans++;
res[q[i].id] = ans;
}
for (int i = 1; i <= m; i++) cout << res[i] << endl;
return 0;
}
但是我们交上去,发现全部TLE了,那是因为 $l$ 和 $r$ 可能会移动很大的范围,导致复杂度上升到 $n^2$。
那么我们是否可以限制他们移动的范围呢?没错,我们将所有询问离线下来,采用分块的思想,将 $n$ 分为 $[1,d],[d+1,2d] ... [pd+1,n]$ 这些区域。
我们将 $l$ 在同一个区域的询问放在一起,然后对于同一个区域的询问按照 $r$ 从小到大排序。这样的话,每个 $l$ 最多只会移动 $d$ 次,所以 $l$ 一共最多会移动 $md$ 次,对于一个区域,$r$ 由于是递增的,也最多只会移动 $n$ 次,而一共有 $n/d$ 个区域,所以 $r$ 最多移动 $n^2/d$ 次。 那么我们要移动的次数最小,就是要 $md+n^2/d$ 最小了因此 $d$ 取 $\sqrt{n^2/m}$ 时时间复杂度最小,大约为 $O(n \times \sqrt{n})$。
#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <vector>
using namespace std;
const int maxn = 50010;
int n, m, m_, k, d, c[maxn], a[maxn];
long long ans, res[maxn];
struct E {
int l, r, id;
};
vector<E> q;
bool cmp(E x, E y) {
if (x.l / d != y.l / d) return x.l / d < y.l / d;
return x.r < y.r;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n >> m >> k;
d = sqrt(n * n / m);
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 1; i <= m; i++) {
E t;
cin >> t.l >> t.r;
t.id = i;
q.push_back(t);
}
sort(q.begin(), q.end(), cmp);
for (int i = q[0].l; i <= q[0].r; i++) c[a[i]]++;
for (int i = 1; i <= k; i++) ans += c[i] * c[i];
res[q[0].id] = ans; // 暴力算第一个
int l = q[0].l, r = q[0].r;
for (int i = 1; i < m; i++) {
while (l > q[i].l) ans += (c[a[--l]]++) * 2, ans++;
while (r < q[i].r) ans += (c[a[++r]]++) * 2, ans++;
while (l < q[i].l) ans -= (c[a[l++]]--) * 2, ans++;
while (r > q[i].r) ans -= (c[a[r--]]--) * 2, ans++;
res[q[i].id] = ans;
}
for (int i = 1; i <= m; i++) cout << res[i] << endl;
return 0;
}
树上莫队
树上莫队很简单,就是将树通过 $dfs$ 序或者欧拉序转换为序列问题。
例题
题意
给定一棵 $n$ 个节点的树,根节点为 $1$。每个节点上有一个颜色 $c_i$。$m$ 次操作。操作有一种:
u k
:询问在以 $u$ 为根的子树中,出现次数 $\ge k$ 的颜色有多少种。
$2\le n\le 10^5$,$1\le m\le 10^5$,$1\le c_i,k\le 10^5$。
题解
我们直接对这个数做 $dfs$ 序,那么询问 $u$ 就相当于询问 $[dfn_u,dfn_u+size_u-1$ ($dfn_u,size_u$ 分别表示 $u$ 的 $dfs$ 序和子树大小)
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
using namespace std;
const int maxn = 500010;
int n, m, d, l, r, tot, siz[maxn], res[maxn];
int v[maxn], s[maxn];
vector<int> g[maxn];
struct E {
int l, r, k, id;
} ask[maxn];
struct Node {
int col, dfn, last;
} a[maxn];
void dfs(int x) {
a.dfn = ++tot;
siz = 1;
for (int y : g) {
if (!a[y].dfn) dfs(y), siz += siz[y];
}
a.last = tot;
return;
}
bool cmp(E x, E y) {
if (x.l / d != y.l / d) return x.l / d < y.l / d;
return x.r < y.r;
}
bool comp(Node x, Node y) {
return x.dfn < y.dfn;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n >> m;
d = max((int)sqrt(n * n / m), 1);
for (int i = 1; i <= n; i++) cin >> a[i].col;
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1);
for (int i = 1; i <= m; i++) {
int x, k;
cin >> x >> k;
ask[i].id = i;
ask[i].l = a.dfn;
ask[i].k = k;
ask[i].r = a.last;
}
sort(ask + 1, ask + m + 1, cmp);
sort(a + 1, a + n + 1, comp);
l = 1;
r = 0;
for (int i = 1; i <= m; i++) {
while (l > ask[i].l) s[++v[a[--l].col]]++;
while (r < ask[i].r) s[++v[a[++r].col]]++;
while (l < ask[i].l) s[v[a[l++].col]--]--;
while (r > ask[i].r) s[v[a[r--].col]--]--;
res[ask[i].id] = s[ask[i].k];
}
for (int i = 1; i <= m; i++) cout << res[i] << '\n';
return 0;
}