LOJ2339 「WC2018」通道

题意:给定三棵大小为 $n$ 的树,求 $\min_{i < j} \{\text{dist}_1(i, j) + \text{dist}_2(i, j) + \text{dist}_3(i, j)\}$。$n\leq 10 ^ 5$,$w_i\leq 10 ^ {12}$。

首先考虑两棵树的做法:我们考虑对第一棵树边分治,然后分为 $S_L, S_R$ 过后,每个点带一个权值 $v(x)$,我们需要求 $\max_{x\in S_L, y\in S_R} \{ v(x) + v(y) + \text{dist}_2 (x, y) \}$。考虑直接对 $S_L\cup S_R$ 在第二棵树上建虚树,这样的话,我们只需要求 $x$ 子树内部 $\max_{u\in S_L} \text{dist}_2(u,x) + v(u)$ 和类似的对于 $S_R$ 的定义,就可以在 $O(sz)$ 内完成。于是总复杂度就是 $O(n\log n)$,如果使用神秘排序建虚树之类的话可以做到 $O(n\log n)$。

现在考虑我们多出来了一棵树怎么做。我们像上面那样建虚树过后合并 $S_1, S_2$ 两棵子树的答案的时候相当于是 $\max_{x\in S_1\cap S_L, y\in S_2\cap S_R} \{v(x) + v(y) + p_x + p_y - 2p_u + \text{dist}_3(x, y) \}$,当然 $x, y$ 反过来还可以贡献。注意到 $u$ 是当前点不需要管,我们只需要把剩余部分的最大值求好即可。这个本质相当于是我们在第三棵树的 $x$ 下面挂了一个长度为 $v(x) + p_x$ 的边,然后求集合内部的直径。容易发现计算集合的直径的时候可以直接从下方合并,于是就做完了。

时间复杂度 $O(n\log ^ 2 n)$ 或 $O(n\log n)$,可能带有较大的常数,可以通过。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
struct Node {
int a[2][2];
Node() : a{} {}
auto& operator [](int x) { return a[x]; }
Node operator *(Node b) const {
Node c;
for (int i : {0, 1}) {
LL mx = -1;
int tmp[4]{a[i][0], a[i][1], b[i][0], b[i][1]};
for (int j = 0; j < 4; ++ j)
for (int k = j + 1; k < 4; ++ k)
if (chkmax(mx, d1[tmp[j]] + d1[tmp[k]] + T2.pre[tmp[j]] + T2.pre[tmp[k]]
+ T3.dist(tmp[j], tmp[k])))
c[i][0] = tmp[j], c[i][1] = tmp[k];
}
return c;
}
};

void edge_prepare(int x, int fa = 0)
{
int ls = 0;
for (int i = h4[x], v; ~i; i = ne[i]) {
if ((v = e[i]) == fa) continue;
if (!ls) ls = x;
else link(h1, ls, ++ cnt1, 0), ls = cnt1;
link(h1, ls, v, w[i]), edge_prepare(v, x);
}
}

int get_size(int x, int fa = 0)
{
int sz = 1;
for (int i = h1[x], v; ~i; i = ne[i])
if (!vis[i >> 1] && (v = e[i]) != fa) sz += get_size(v, x);
return sz;
}

int get_ec(int x, int frm, int tot, int &mx, int &eid)
{
int sz = 1;
for (int i = h1[x]; ~i; i = ne[i])
if (!vis[i >> 1] && i != (frm ^ 1)) sz += get_ec(e[i], i, tot, mx, eid);
if (chkmin(mx, std::max(sz, tot - sz))) eid = frm;
return sz;
}

void get_side(int x, int fa, LL dis, std::vector<int> &nds)
{
if (x <= n) nds.push_back(x), d1[x] = dis;
for (int i = h1[x], v; ~i; i = ne[i])
if (!vis[i >> 1] && (v = e[i]) != fa) get_side(v, x, dis + w[i], nds);
}

void insert_vir(int x)
{
if (stk[top] == x) return;
int lca = T2.LCA(x, stk[top]);
while (top > 1 && T2.dep[stk[top - 1]] >= T2.dep[lca])
link(h5, stk[top], stk[top - 1], 0), top --;
if (T2.dep[stk[top]] > T2.dep[lca]) link(h5, stk[top], lca, 0), top --;
if (stk[top] != lca) stk[++ top] = lca;
stk[++ top] = x;
}

Node dfs_ans(int x, int fa = 0)
{
Node ret{};
if (sid[x] == 2) ret[1][0] = ret[1][1] = x;
else if (sid[x] == 1) ret[0][0] = ret[0][1] = x;
auto getval = [&](int u, int v) {
if (!u || !v) return 0LL;
return d1[u] + d1[v] + T2.pre[u] + T2.pre[v] - 2 * T2.pre[x] + T3.dist(u, v);
};
for (int i = h5[x], v; ~i; i = ne[i])
{
if ((v = e[i]) == fa) continue;
// std::cout << "Edge " << x << ' ' << v << ' ' << i << std::endl;
Node to = dfs_ans(v, x);
for (int i : {0, 1})
for (int j : {0, 1})
for (int k : {0, 1})
chkmax(cans, getval(ret[i][j], to[i ^ 1][k]));
ret = ret * to;
}
return ret;
}

void dfs_clear(int x, int fa = 0)
{
for (int i = h5[x], v; ~i; i = ne[i])
if ((v = e[i]) != fa) dfs_clear(v, x);
h5[x] = -1, sid[x] = 0;
}

void divide(int x)
{
int sz = get_size(x), eid = -1;
if (sz == 1) return;
get_ec(x, -1, sz, sz, eid), vis[eid >> 1] = true;
// std::cout << "Divide " << e[eid] << ' ' << e[eid ^ 1] << std::endl;
std::vector<int> nds, ndr;
get_side(e[eid], e[eid ^ 1], 0, nds), get_side(e[eid ^ 1], e[eid], 0, ndr);
for (int x : nds) sid[x] = 1;
for (int x : ndr) sid[x] = 2, nds.push_back(x);
std::sort(nds.begin(), nds.end(), [&](int x, int y) {
return T2.fi[x] < T2.fi[y];
});
stk[top = 1] = 1;
int bac = idx;
for (int x : nds) insert_vir(x);
while (-- top)
link(h5, stk[top + 1], stk[top], 0);
cans = 0;
dfs_ans(1, 0), chkmax(res, cans + w[eid]), dfs_clear(1, 0), idx = bac;
// std::cout << cans << ' ' << w[eid] << '\n';
divide(e[eid]), divide(e[eid ^ 1]);
}

int main()
{
for (int i = 0; i < N; ++ i) h1[i] = h2[i] = h3[i] = h4[i] = h5[i] = -1;
for (int i = N; i < 2 * N; ++ i) h1[i] = -1;
std::cin >> n;
int u, v;
LL w;
for (int i = 1; i < n; ++ i)
{
scanf("%d %d %lld", &u, &v, &w);
link(h4, u, v, w);
}
for (int i = 1; i < n; ++ i)
{
scanf("%d %d %lld", &u, &v, &w);
link(h2, u, v, w);
}
for (int i = 1; i < n; ++ i)
{
scanf("%d %d %lld", &u, &v, &w);
link(h3, u, v, w);
}
cnt1 = n, edge_prepare(1), T2.prework(h2), T3.prework(h3), T4.prework(h4);
divide(1);
std::cout << res << '\n';
return 0;
}