https://codeforces.com/contest/1923/problem/E
题意
给你一棵树,它由 $n$ 个顶点组成,编号从 $1$ 到 $n$。每个顶点都有某种颜色 $a_i$ 满足 $1 \le a_i \le n$。
如果符合以下条件,那么这棵树的一条简单路径就叫做美丽路径:
- 至少由 $2$ 个顶点组成;
- 路径的第一个顶点和最后一个顶点的颜色相同;
- 路径上没有其他顶点的颜色与第一个顶点相同。
计算这棵树的美丽简单路径的数量。请注意,路径是不定向的(即从 $x$ 到 $y$ 的路径与从 $y$ 到 $x$ 的路径算作一条路径)。
题解
My solution
树上路径问题,我们考虑点分治。
关于点分治
考虑分治到点 $x$,路径分为经过 $x$ 的路径与不经过 $x$ 的路径,后者我们会分治到 $x$ 的子树时处理,我们现在考虑如何求解前者。
计算前者,我们遍历所有子树,分别记录 $mp_{v}$ 表示满足颜色为 $v$,且路径 $(x, fa_y)$ 上任意点颜色不等于 $a_y$ 的路径 $(x,y)$ 的数量。实现的话可以在遍历子树时,记录从 $x$ 到当前节点的每个颜色出现的次数。最后将所有子树的每“半条路径”拼接起来,计算答案。
// LUOGU_RID: 148821060
#include <iostream>
#include <vector>
#pragma GCC optimize ("Ofast")
#pragma GCC optimize (3)
#define inf 1000000000
using namespace std;
typedef long long ll;
const int maxn = 200010;
int T;
int n, a[maxn], cnt, h[maxn];
int s[maxn], ms[maxn];
int sum, rt;
bool vis[maxn];
int b[maxn], c[maxn];
ll ans;
vector <int> t, _t;
struct E {
int to, ne;
} e[maxn << 1];
inline void add (int u, int v) {
e[++cnt].to = v; e[cnt].ne = h[u]; h[u] = cnt;
}
void find (int x, int fa) { //寻找重心并计算子树大小
ms = 0;
s = 1;
for (int i = h; i; i = e[i].ne) {
int y = e[i].to;
if (!vis[y] && y != fa) {
find (y, x);
s += s[y];
ms = max (ms, s[y]);
}
}
ms = max (ms, sum - s);
if (ms < ms[rt]) rt = x;
}
void getinfo (int x, int fa) { //遍历子树寻找“半条路径”
if (!b[a]) t.push_back (a);//一条合法的“半条路径”
b[a]++; //记录颜色出现次数
for (int i = h; i; i = e[i].ne) {
int y = e[i].to;
if (!vis[y] && y != fa) {
getinfo (y, x);
}
}
b[a]--; //记录颜色出现次数
}
void dfs (int x, int fa) {
vis = true;
for (int i = h; i; i = e[i].ne) {
int y = e[i].to;
if (y != fa && !vis[y]) {
getinfo (y, x);
for (int d : t) {
if (d == a) ans++;
else ans += c[d];
} //合并“半条路径”
for (int d : t) {
if (c[d]++ == 0) _t.push_back (d);
}
}
t.clear ();
}
for (int x : _t) c = 0; //c相当于mp,不能暴力清空,因此需要记下每次所更改的位置
_t.clear ();
for (int i = h; i; i = e[i].ne) {
int y = e[i].to;
if (y != fa && !vis[y]) {
sum = s[y];
ms[rt = 0] = inf;
find (y, x);
find (rt, 0);
dfs (rt, x);//继续分治
}
}
}
int main () {
ios::sync_with_stdio (false);
cin.tie (0);
cin >> T;
while (T--) {
cin >> n;
fill (vis + 1, vis + n + 1, false);
fill (h + 1, h + n + 1, 0);
fill (c + 1, c + n + 1, 0);
ans = cnt = 0;
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
add (u, v); add (v, u);
}
ms[rt = 0] = inf;
sum = n;
find (1, 0);
find (rt, 0);
dfs (rt, 0);
cout << ans << '\n';
}
return 0;
}