CF809E Surprise me!

比较套路的莫比乌斯反演,以及虚树处理。

题意:给定一棵 $n$ 点的树,边权均为 1,每个点有一个权值 $a_i$ ,求:

$n\leq 2\times 10 ^ 5$,保证 $a$ 是一个 $1\sim n$ 的排列。

首先一个重要的套路是数论函数内尽量不要有多个变量,尽量转化为单变量

观察 $\varphi(n)$ 的定义式,容易知道这是 $n\prod_{p | n}(1 - \dfrac 1p)$,$p$ 是质数。先把 $(i, j)$ 连边直接连 $(a_i, a_j)$,这样就不用考虑编号问题了。如果 $i, j$ 同时有 $p$ 这个质因子,会乘两道 $1 - \dfrac 1p$,所以需要 $\gcd(i, j)$ 除掉一个。再自己凑一下就可以得到:

带入原式,看到 $\gcd$,直接暴力莫比乌斯反演:

然后我们可以枚举 $T$,这样有用的节点就只有 $\dfrac nT$ 个了。我们相当于是求:

对这些节点建虚树计算,即可保证复杂度。具体的,直接考虑树形 DP,维护子树内的 $\varphi$ 和,以及子树内到他的距离乘 $\varphi$ 的和。这样就可以换根 DP 计算。

由于建虚树需要 $O(m\log m)$,$m$ 为总点数,那么最后复杂度为 $O(n\log ^ 2 n)$,可以通过。

坑点:注意虚树上有些节点是不能计算权值的,某些不是 $T$ 的节点应该把 $v$ 设为 0。虚树上的边权不为 1,转移略麻烦。注意清空,老问题了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
namespace SolvingLCA {
int st[N << 1][19], lg[N << 1], fi[N], cnt;

void dfs(int x, int fa = 0)
{
st[++ cnt][0] = x, fi[x] = cnt;
for (int i = h1[x], v; ~i; i = ne[i])
if ((v = e[i]) ^ fa) dfs(v, x), st[++ cnt][0] = x;
}

inline int dmin(int x, int y) { return fi[x] < fi[y] ? x : y; }

void prework()
{
dfs(1);
for (int i = 2; i <= cnt; ++ i) lg[i] = lg[i >> 1] + 1;
for (int j = 1; j <= lg[cnt]; ++ j)
for (int i = 1; i + (1 << j) - 1 <= cnt; ++ i)
st[i][j] = dmin(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
}

int LCA(int x, int y) {
if (fi[x] > fi[y]) std::swap(x, y);
int k = lg[fi[y] - fi[x] + 1];
return dmin(st[fi[x]][k], st[fi[y] - (1 << k) + 1][k]);
}
}
using SolvingLCA::LCA;

inline int& adj(int &x) { return x += x >> 31 & Mod; }

int qpow(int a, int k = Mod - 2)
{
int res = 1;
for (; k; k >>= 1, a = (LL) a * a % Mod)
(k & 1) && (res = (LL) res * a % Mod);
return res;
}

void add(int *h, int a, int b, int c)
{
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx ++;
}
void link(int *h, int a, int b, int c = 1) { add(h, a, b, c), add(h, b, a, c); }

void sieve()
{
st[1] = 1, phi[1] = mu[1] = 1;
for (int i = 2; i < N; ++ i)
{
if (!st[i]) prime[cnt ++] = i, phi[i] = i - 1, mu[i] = Mod - 1;
for (int j = 0; j < cnt && i * prime[j] < N; ++ j)
{
st[i * prime[j]] = true;
if (i % prime[j] == 0) {
phi[i * prime[j]] = phi[i] * prime[j];
break;
}
phi[i * prime[j]] = phi[i] * (prime[j] - 1);
mu[i * prime[j]] = Mod - mu[i];
}
}
}

void dfs(int x, int fa = 0)
{
dfn[x] = ++ *dfn, dep[x] = dep[fa] + 1;
for (int i = h1[x], v; ~i; i = ne[i])
if ((v = e[i]) != fa) dfs(v, x);
}

void insert(int x)
{
if (!top) return void(stk[top = 1] = x);
int lca = LCA(stk[top], x);
while (top > 1 && dep[stk[top - 1]] > dep[lca])
link(h2, stk[top - 1], stk[top], dep[stk[top]] - dep[stk[top - 1]]), top --;
if (dep[stk[top]] > dep[lca]) link(h2, lca, stk[top], dep[stk[top]] - dep[lca]), top --;
if (lca != stk[top]) stk[++ top] = lca, extranodes.push_back(lca);
stk[++ top] = x;
}

void dfs1(int x, int fa = 0)
{
sz[x] = val[x], dis[x] = 0;
for (int i = h2[x], v; ~i; i = ne[i])
if ((v = e[i]) != fa)
dfs1(v, x), adj(sz[x] += sz[v] - Mod),
dis[x] = (dis[x] + dis[v] + (LL) sz[v] * w[i]) % Mod;
}

void dfs2(int x, int fa = 0, int prew = 0)
{
if (x != 1)
tdis[x] = (tdis[fa] + (sz[1] - 2LL * sz[x] + Mod) * prew) % Mod;
for (int i = h2[x], v; ~i; i = ne[i])
if ((v = e[i]) ^ fa) dfs2(v, x, w[i]);
}

int main()
{
memset(h1, -1, sizeof(h1));
memset(h2, -1, sizeof(h2));
sieve();
std::cin >> n;
for (int i = 1; i <= n; ++ i) scanf("%d", a + i);
for (int i = 1, u, v; i < n; ++ i)
{
scanf("%d %d", &u, &v), u = a[u], v = a[v];
link(h1, u, v);
}
dfs(1), SolvingLCA::prework();
for (int d = 1; d <= n; ++ d)
for (int i = 1; i * d <= n; ++ i)
mul[i * d] = (mul[i * d] + (LL) d * mu[i] % Mod * qpow(phi[d])) % Mod;
int res = 0, frmidx = idx;
for (int d = 1; d <= n; ++ d)
{
allnodes.clear(), extranodes.clear();
for (int i = d; i <= n; i += d)
allnodes.push_back(i), val[i] = phi[i];
extranodes = allnodes;
std::sort(allnodes.begin(), allnodes.end(), [&](int x, int y) {
return dfn[x] < dfn[y];
});
if (d ^ 1) stk[top = 1] = 1, extranodes.push_back(1);
for (int x : allnodes) insert(x);
while (-- top) link(h2, stk[top], stk[top + 1], dep[stk[top + 1]] - dep[stk[top]]);
dfs1(1), tdis[1] = dis[1], dfs2(1);
int cur = 0;
for (int x : allnodes)
cur = (cur + (LL) phi[x] * tdis[x]) % Mod;
res = (res + mul[d] * (LL) cur) % Mod;
for (int x : allnodes) h2[x] = -1, val[x] = 0;
for (int x : extranodes) h2[x] = -1;
idx = frmidx;
}
res = (LL) res * qpow(n * (n - 1LL) % Mod) % Mod;
std::cout << res << '\n';
return 0;
}