AHOI2022 钥匙

码量题,vp 时差点写吐了(当然也和我选的方法不够简洁有关),没调出来。

题意:给定一棵树,每个节点上有一把钥匙或一个宝箱,钥匙和宝箱都有一个颜色,只有相同颜色的钥匙才能打开宝箱(每把钥匙只能使用一次)。同一种颜色的钥匙最多只有 5 把。进行 $q$ 次旅行,每次从 $s$ 沿树上路径走到 $t$,沿途拾取钥匙、用手中同色的钥匙打开宝箱,问每次旅行能打开多少宝箱。$n\leq 5\times 10^5,\ q\leq 10^6$。

容易发现我们要对每一种颜色建一棵虚树。

建出虚树后,由于每种颜色的钥匙最多只有 5 把,暴力地以每一把钥匙为起点在虚树上搜索是可行的。搜索时维护走到每个节点时手中还剩几把钥匙(遇到钥匙加一,遇到宝箱减一);如果在某个宝箱处恰好减到 0,说明起点的钥匙会和这个宝箱匹配,此时不再继续向下搜索。当一条旅行路径先经过这把钥匙、再经过这个宝箱时,答案就会多 1。

最后就是一个路径覆盖问题:对每个匹配对 $(a, b)$($a$ 为钥匙,$b$ 为宝箱),统计有多少询问 $(s, t)$ 的路径按顺序经过 $a$ 和 $b$。注意需要分类讨论:$a$ 是 $b$ 的祖先、$b$ 是 $a$ 的祖先、以及其余情况。把 $s$、$t$ 拍到 dfn 序上后,每一类都对应平面上的若干矩形,扫描线 + 树状数组即可解决。

代码是在考场代码上改的,很冗长,仅供参考。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
struct Tree {
int h[N], e[M], ne[M], idx;
int st[N << 1][22], lg[N << 1], fi[N], cnt;
int dep[N], d1[N], d2[N], typ[N], f[N], up[N], frm, rem[N];
std::vector<int> ans;
void init(int n)
{
for (int i = 1; i <= n; ++ i) h[i] = -1;
for (int i = 1; i <= n; ++ i) typ[i] = 0;
idx = cnt = frm = 0;
}
void add(int a, int b) { e[idx] = b, ne[idx] = h[a], h[a] = idx ++; }
void link(int a, int b) { add(a, b), add(b, a); }

void dfs(int x, int fa = 0)
{
st[++ cnt][0] = x, fi[x] = cnt, dep[x] = dep[fa] + 1;
d1[x] = d1[fa], d2[x] = d2[fa], f[x] = fa, up[x] = frm;
if (typ[x] == 1) d1[x] ++, frm = x;
else if (typ[x] == 2) d2[x] ++;
int tmp = frm;
for (int i = h[x]; ~i; i = ne[i])
if (e[i] ^ fa) dfs(e[i], x), st[++ cnt][0] = x, frm = tmp;
}

inline int dmin(int x, int y) { return dep[x] < dep[y] ? x : y; }

void prework()
{
dfs(1);
for (int i = 2; i <= cnt; ++ i) lg[i] = lg[i >> 1] + 1;
for (int j = 1; j <= lg[cnt]; ++ j)
for (int i = 1; i + (1 << j) - 1 <= cnt; ++ i)
st[i][j] = dmin(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
}

int LCA(int x, int y) {
if (fi[x] > fi[y]) std::swap(x, y);
int k = lg[fi[y] - fi[x] + 1];
return dmin(st[fi[x]][k], st[fi[y] - (1 << k) + 1][k]);
}

int dist1(int x, int y) {
int t = LCA(x, y);
return d1[x] + d1[y] - d1[t] - d1[f[t]];
}
int dist2(int x, int y) {
int t = LCA(x, y);
return d2[x] + d2[y] - d2[t] - d2[f[t]];
}

void dfs2(int x, int fa = 0)
{
rem[x] = rem[fa];
if (typ[x] == 1) rem[x] ++;
if (typ[x] == 2) rem[x] --;
if (rem[x] == 0) return ans.push_back(x);
for (int i = h[x]; ~i; i = ne[i])
if (e[i] ^ fa) dfs2(e[i], x);
}

std::vector<int> getmatch(int x) {
ans.clear(), dfs2(x);
return ans;
}
} tr, oc;

// Fenwick tree over indices 1..N-1: point add, prefix-sum query.
// main() uses it for range-add / point-query via difference updates.
struct BIT {
    int tr[N];

    // Add c at index x (x must be >= 1).
    void add(int x, int c)
    {
        while (x < N) {
            tr[x] += c;
            x += x & -x;
        }
    }

    // Sum over indices 1..x.
    int ask(int x)
    {
        int sum = 0;
        while (x != 0) {
            sum += tr[x];
            x -= x & -x;
        }
        return sum;
    }
} bt;

// Push vertex x (vertices must arrive in increasing dfn order) onto the
// virtual-tree stack. Edges are emitted into `oc` using the slot numbers in
// pos[]; a new LCA without a slot gets the next one from curcnt. Every LCA
// is recorded in allnodes so its slot can be cleared afterwards.
void insert(int x)
{
    if (top == 0) {
        top = 1;
        stk[top] = x;
        return;
    }
    const int anc = tr.LCA(x, stk[top]);
    allnodes.push_back(anc);
    if (pos[anc] == 0) {
        ++curcnt;
        pos[anc] = curcnt;
    }
    // Close every stacked vertex that lies at or below anc's depth.
    for (; top > 1 && tr.dep[stk[top - 1]] >= tr.dep[anc]; --top)
        oc.link(pos[stk[top - 1]], pos[stk[top]]);
    if (stk[top] != anc) {
        oc.link(pos[anc], pos[stk[top]]);
        stk[top] = anc;
    }
    if (anc != x) {
        ++top;
        stk[top] = x;
    }
}

// Preorder numbering of the input tree (adjacency g): dfn[x] is x's preorder
// index, nw[] its inverse and sz[x] the subtree size. dfn[0] is used as the
// running counter.
void dfs(int x, int fa = 0)
{
    ++dfn[0];
    dfn[x] = dfn[0];
    sz[x] = 1;
    nw[dfn[0]] = x;
    for (int v : g[x]) {
        if (v == fa) continue;
        dfs(v, x);
        sz[x] += sz[v];
    }
}

int main()
{
#ifndef MyRun
// freopen("keys.in", "r", stdin);
// freopen("keys.out", "w", stdout);
#endif
    read(n, Q);
    for (int i = 1; i <= n; ++ i) read(typ[i], col[i]);
    tr.init(n);
    for (int i = 1, u, v; i < n; ++ i)
        read(u, v), g[u].push_back(v), g[v].push_back(u), tr.add(u, v), tr.add(v, u);
    tr.prework();
    for (int i = 1; i <= n; ++ i) {
        all[col[i]].push_back(i);
        if (typ[i] == 1) ky[col[i]].push_back(i);
        else bx[col[i]].push_back(i);
    }
    dfs(1);

    // kids[v]: children of v in the tree rooted at 1. dfs() visits g[v] in
    // order, so the children come out already sorted by dfn.
    std::vector<std::vector<int>> kids(n + 1);
    for (int v = 1; v <= n; ++ v)
        for (int w : g[v])
            if (dfn[w] > dfn[v]) kids[v].push_back(w);

    // Child of `anc` whose subtree contains the proper descendant `des`: the
    // last child whose dfn does not exceed dfn[des]. A linear scan of g[anc]
    // per matched pair is O(deg) each and O(n^2) on a star centred on a key.
    auto childToward = [&](int anc, int des) {
        const std::vector<int> &ch = kids[anc];
        auto it = std::upper_bound(ch.begin(), ch.end(), dfn[des],
                                   [&](int d, int c) { return d < dfn[c]; });
        assert(it != ch.begin());
        return *(it - 1);
    };

    auto cmp = [&](int x, int y) { return dfn[x] < dfn[y]; };
    for (int i = 1; i <= n; ++ i)
        std::sort(all[i].begin(), all[i].end(), cmp);

    // Per colour, build the virtual tree in `oc`: vertex 1 is slot 1, the
    // colour's vertices are slots 2.., and extra LCAs take the slots after
    // that. Then walk from every key and record, as rectangles over
    // (dfn[s], dfn[t]), the queries whose path uses that key on its box.
    for (int i = 1; i <= n; ++ i) {
        if (!ky[i].size()) continue;
        oc.init(std::min((int) all[i].size() * 2 + 1, n + 1));
        for (int j = 0; j < (int) all[i].size(); ++ j)
            pos[all[i][j]] = j + 2;
        allnodes = all[i], allnodes.push_back(1);
        pos[1] = 1, insert(1), curcnt = all[i].size() + 1;
        for (int x : all[i])
            if (x ^ 1) insert(x);
        while (-- top) oc.link(pos[stk[top]], pos[stk[top + 1]]);
        for (int x : ky[i]) oc.typ[pos[x]] = 1;
        for (int x : bx[i]) oc.typ[pos[x]] = 2;
        oc.prework();
        for (int p1 : ky[i]) {
            auto mat = oc.getmatch(pos[p1]);
            for (int p2 : mat) {
                // Map the virtual-tree slot back to an original vertex.
                if (p2 > (int) all[i].size() + 1) continue;   // auxiliary LCA slot
                if (p2 == 1 && bx[i].front() != 1) continue;  // vertex 1 is not a box of colour i
                if (p2 == 1) p2 ++;                           // vertex 1 is then all[i][0]
                p2 = all[i][p2 - 2];
                const auto l2 = dfn[p2];
                const auto r2 = dfn[p2] + sz[p2] - 1;          // dfn range of subtree(p2)
                if (dfn[p1] <= dfn[p2] && dfn[p2] < dfn[p1] + sz[p1]) {
                    // p2 below p1: s outside subtree(son), t inside subtree(p2).
                    int son = childToward(p1, p2);
                    opt[1].push_back({l2, r2, 1});
                    opt[dfn[son]].push_back({l2, r2, -1});
                    opt[dfn[son] + sz[son]].push_back({l2, r2, 1});
                } else if (dfn[p2] <= dfn[p1] && dfn[p1] < dfn[p2] + sz[p2]) {
                    // p1 below p2: s inside subtree(p1), t outside subtree(son).
                    int son = childToward(p2, p1);
                    opt[dfn[p1]].push_back({1, dfn[son] - 1, 1});
                    opt[dfn[p1]].push_back({dfn[son] + sz[son], n, 1});
                    opt[dfn[p1] + sz[p1]].push_back({1, dfn[son] - 1, -1});
                    opt[dfn[p1] + sz[p1]].push_back({dfn[son] + sz[son], n, -1});
                } else {
                    // Disjoint subtrees: s inside subtree(p1), t inside subtree(p2).
                    opt[dfn[p1]].push_back({l2, r2, 1});
                    opt[dfn[p1] + sz[p1]].push_back({l2, r2, -1});
                }
            }
        }

        for (int x : allnodes) pos[x] = 0;
    }
    for (int i = 1, st, ed; i <= Q; ++ i) {
        read(st, ed);
        assert(dfn[st] >= 1 && dfn[st] <= n);
        q[dfn[st]].push_back({dfn[ed], i});
    }
    // Sweep dfn[s] from left to right. The BIT over dfn[t] holds, for every
    // t, the number of recorded rectangles covering (dfn[s], dfn[t]).
    for (int i = 1; i <= n; ++ i)
    {
        for (const auto &p : opt[i])
            bt.add(p.l, p.c), bt.add(p.r + 1, -p.c);
        for (const auto &p : q[i])
            res[p.second] = bt.ask(p.first);
    }
    for (int i = 1; i <= Q; ++ i) printf("%d\n", res[i]);
    return 0;
}