UOJ559 [NOI2020]命运

题意:给定一棵 $n$ 个点的有根树和 $m$ 条直上直下的链,问有多少种给每条边赋值 0 或 1 的方案数使得 $m$ 条链至少有一个 1。$n, m\leq 5\times 10 ^ 5$,2s,1024 MB。

题意明摆着叫你容斥,那么我们就容斥呗(但是好像有简单 DP 做法?)。暴力容斥,钦定一些链必须全部为 0,然后乘上 $(-1) ^ x$ 贡献到答案。直接做是 $O(2 ^ m n)$,用 DP 优化,考虑设 $f_{i, j}$ 表示 $i$ 为根的子树内已经选好,且从 $i$ 到 $j$ 链($j$ 是 $i$ 的祖先)已经全部被钦定为 0 了的方案数。这样就可以写出转移:

$j, k$ 按照 $dep$ 比较,$u, v$ 是一条直上直下的链。这样直接做是 $O(n ^ 2)$ 的,仍然无法通过。

我们观察转移,注意到非 0 项的个数其实并不多,产生一个非零项的方法只有从 $f_{x, x}$ 和从 $-f_{x, x}\rightarrow f_{x, u_i}$ 两个部分。容易发现产生的总 0 项的个数是 $O(n + m)$ 级别的。其余都是在不同的 $x$ 之间转移,于是这个可以使用线段树合并转移,时空复杂度均为 $O((n + q)\log n)$,可以通过。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
void pushup(int x) { adj(tr[x].sum = tr[tr[x].lc].sum + tr[tr[x].rc].sum - Mod); }

void update(int x, int c) {
if (!x) return;
tr[x].sum = (LL) tr[x].sum * c % Mod;
tr[x].lt = (LL) tr[x].lt * c % Mod;
}

void pushdown(int x)
{
if (tr[x].lt == 1) return;
update(tr[x].lc, tr[x].lt), update(tr[x].rc, tr[x].lt);
tr[x].lt = 1;
}

int merge(int p, int q, int l, int r, int sump, int sumq)
{
if (!p && !q) return 0;
if (!p) return update(q, sump), q;
if (!q) return update(p, sumq), p;
if (l == r) {
tr[p].sum = ((LL) (tr[p].sum + sump) * tr[q].sum + (LL) sumq * tr[p].sum) % Mod;
return p;
}
pushdown(p), pushdown(q);
int mid = (l + r) >> 1, rp = tr[tr[p].rc].sum, rq = tr[tr[q].rc].sum;
tr[p].lc = merge(tr[p].lc, tr[q].lc, l, mid,
adj(rp = sump + rp - Mod), adj(rq = sumq + rq - Mod));
tr[p].rc = merge(tr[p].rc, tr[q].rc, mid + 1, r, sump, sumq);
return pushup(p), p;
}

int query(int x, int ql, int qr, int l, int r)
{
if (!x || ql > r || qr < l) return 0;
if (l >= ql && r <= qr) return tr[x].sum;
pushdown(x);
int mid = (l + r) >> 1;
return query(tr[x].lc, ql, qr, l, mid)
+ query(tr[x].rc, ql, qr, mid + 1, r);
}

void modify(int &x, int l, int r, int pos, int c)
{
if (!x) x = ++ nodecnt;
if (l == r) return void(adj(tr[x].sum += c - Mod));
pushdown(x);
int mid = (l + r) >> 1;
if (pos <= mid) modify(tr[x].lc, l, mid, pos, c);
else modify(tr[x].rc, mid + 1, r, pos, c);
pushup(x);
}

void rdfs(int x, int fa = 0)
{
std::sort(con[x].begin(), con[x].end(), std::greater<int>());
con[x].erase(std::unique(con[x].begin(), con[x].end()), con[x].end());
if (con[x].size())
modify(rt[x], 1, n, con[x][0], Mod - 1);
modify(rt[x], 1, n, dep[x], 1);
for (int v : g[x])
{
if (v == fa) continue;
rdfs(v, x);
int extra = query(rt[v], dep[v], dep[v], 1, n);
modify(rt[v], 1, n, dep[v], Mod - extra);
modify(rt[v], 1, n, dep[x], extra * 2 % Mod);
rt[x] = merge(rt[x], rt[v], 1, n, 0, 0);
}
/*for (int i = 1; i <= dep[x]; ++ i) std::cout << query(rt[x], i, i, 1, n) << ' ';
std::cout << '\n';*/
}

int main()
{
init();
std::cin >> n;
for (int i = 1, u, v; i < n; ++ i)
{
scanf("%d %d", &u, &v);
g[u].push_back(v), g[v].push_back(u);
}
pre_dfs(1);
std::cin >> m;
for (int i = 1, u, v; i <= m; ++ i)
{
scanf("%d %d", &u, &v);
con[v].push_back(dep[u]);
}
rdfs(1);
std::cout << query(rt[1], 1, 1, 1, n) << '\n';
return 0;
}