CF1515H Phoenix and Bits

题意:维护一个集合 $a$,最开始有 $n$ 个元素,有 $m$ 次操作或者询问:

  1. 将 $a_i\in [l, r]$ 的值全部与 $x$。
  2. 将 $a_i\in [l, r]$ 的值全部或 $x$。
  3. 将 $a_i\in [l, r]$ 的值全部异或 $x$。
  4. 询问在 $[l, r]$ 中有多少个不同的 $a_i$。

$n\leq 2\times 10 ^ 5$,$m\leq 10 ^ 5$,$0\leq a_i < 2 ^ {20}$。

进阶 01 Trie 的模板题。

寻找一段区间

类似于 FHQ Treap 的 split 一样,我们考虑在 $[l, r]$ 之间分离出一段 $[x, y]$。此时考虑 $mid$ 在 $[x, y]$ 的位置,向左右递归即可。节点 split 直接新建即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
void split(int u, int &x, int &y, int l, int r, int bit)
{
x = y = 0;
if (!u || !tr[u].sz) return;
if (bit < 0 || (l == 0 && r == 2 * (1 << bit) - 1)) {
x = u;
return;
}
int mid = 1 << bit;
// std::cout << u << ' ' << l << ' ' << r << ' ' << bit << ' ' << tr[u].sz << '\n';
pushdown(u, bit);
if (l >= mid)
y = u, split(tr[u].s[1], tr[x = ++ tot].s[1], tr[u].s[1], l ^ mid, r ^ mid, bit - 1);
else if (r < mid)
y = u, split(tr[u].s[0], tr[x = ++ tot].s[0], tr[u].s[0], l, r, bit - 1);
else {
// if (l == 0 && r == mid * 2 - 1) return x = u, void();
x = u, y = ++ tot;
split(tr[u].s[0], tr[u].s[0], tr[y].s[0], l, mid - 1, bit - 1);
split(tr[u].s[1], tr[u].s[1], tr[y].s[1], 0, r ^ mid, bit - 1);
}
pushup(x), pushup(y);
}

合并两棵子树

没什么好讲的,注意到了叶子节点只能保留一个信息即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int merge(int x, int y, int bit)
{
// std::cout << "MERGE " << x << ' ' << y << ' ' << bit << '\n';
if (!x || !tr[x].sz) return y;
if (!y || !tr[y].sz) return x;
if (bit < 0) {
// std::cout << tr[x].val << ' ' << tr[y].val << std::endl;
assert(!tr[x].sz || !tr[y].sz || tr[x].val == tr[y].val);
return tr[x].sz ? x : y;
}
pushdown(x, bit), pushdown(y, bit);
tr[x].s[0] = merge(tr[x].s[0], tr[y].s[0], bit - 1);
tr[x].s[1] = merge(tr[x].s[1], tr[y].s[1], bit - 1);
return pushup(x), x;
}

对全局 xor

我们考虑对于一棵子树全部 xor 怎么做。因为我们已经有 split 了,所以这个操作可以支持区间做。

这个可以使用懒标记实现。因为子树的节点个数并没有发生改变,而只是对于一些节点的左右儿子交换即可。这个节点的所有信息都可以在不递归子树的情况下 $O(1)$ 维护。

具体有哪些信息呢?我们目前只需要维护左右儿子、懒标记和不同的数的个数,都很简单,就不讲了。

对全局 or

全局 or 是不好做的,因为我们可能合并节点之类的。虽然总结点个数是 $O((n + m)\log a)$ 的,但是我们并不好判断一棵子树内部有没有需要合并的节点。

首先如果全局 or 在一定情况下可以变成全局 xor,这样不会减少节点,所以我们为了保证复杂度,不能递归。容易发现这个的成立条件是在给定 or 的每一位上全局要么全是 1,要么全是 0。我们考虑记录全局的 or $v_1$ 和全局补集的 or $v_2$(将所有数取反的 or),那么判断条件可以写作 $[(v_2\odot x)\odot v_1 = 0]$($\odot$ 表示与运算)。另外,此时我们相当于要对全局 xor 上 $v_2\odot x$。

否则的话,我们至少会出现一个节点的合并,这个时候我们再考虑递归。如果当前节点左右儿子都存在,并且我们这一位会 or 上 1,那么我们可以先将左儿子的所有值 xor 上 $2 ^ {bit}$,然后直接将左右儿子都合并到右儿子即可。

这个时候再反过去在 xor 的时候维护一下 or 和补集的 or,注意到如果 or 和补集的 or 都包含某一位的话,怎么 xor 这一位都是 1。拆位做一下就好了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
void allxor(int x, int bit, int lt) {
if (!x) return;
int v = tr[x].val, lv = tr[x].lval;
if (lt >> bit & 1) std::swap(tr[x].s[0], tr[x].s[1]);
tr[x].val = (v ^ lt) | (v & lv & lt);
tr[x].lval = (lv ^ lt) | (v & lv & lt);
tr[x].lt ^= lt;
}

void pushdown(int x, int bit)
{
if (!x || !tr[x].lt || !tr[x].sz) return;
allxor(tr[x].s[0], bit - 1, tr[x].lt), allxor(tr[x].s[1], bit - 1, tr[x].lt);
tr[x].lt = 0;
}

void allor(int u, int bit, int val)
{
if (!u || !tr[u].sz) return;
int add = val & tr[u].lval;
if (!(add & tr[u].val)) return allxor(u, bit, add);
pushdown(u, bit);
int mid = 1 << bit;
if (val & mid)
allxor(tr[u].s[0], bit - 1, mid),
tr[u].s[1] = merge(tr[u].s[0], tr[u].s[1], bit - 1), tr[u].s[0] = 0;
allor(tr[u].s[0], bit - 1, val), allor(tr[u].s[1], bit - 1, val);
pushup(u);
}

对全局 and

容易发现 $v\odot x = \neg(\neg v |x)$,那么就可以拆分成 or 和 xor 操作。这个就好做了。

总结

由于一共只有 $O((n + m)\log a)$ 个节点,我们花费 $O(\log a)$ 的时间可以永久删除一个点,所以复杂度时 $O((n + m)\log ^ 2a)$ 的,可以通过。

由于笔者也是第一次也这种代码,出了很多神秘错误,大概列几点可能常见的:

  1. 计算低位的时候,不管 xor 还是 or 都不能只把 $x$ 的低位传下去,因为需要维护子树的值相关的信息。
  2. 由于 split 时可能出现里面一个数都没有的节点,操作时需要特判一下,比如把所有判断空节点都写成 !tr[u].sz
  3. 注意 or 判断改为 xor 时的判断条件要准确,因为小样例测不出来。

只放主函数了,调用的函数上面基本出现了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
int main()
{
std::cin >> n >> m;
for (int i = 1, x; i <= n; ++ i) scanf("%d", &x), insert(x);
for (int i = 1, typ, l, r, v, x, y; i <= m; ++ i) {
scanf("%d %d %d", &typ, &l, &r);
split(rt, x, y, l, r, 19);
// std::cout << "RANGE " << tr[x].sz << ' ' << tr[x].val << ' ' << tr[x].lval << '\n';
if (typ == 4) printf("%d\n", tr[x].sz);
else if (typ == 3) scanf("%d", &v), allxor(x, 19, v);
else if (typ == 2) scanf("%d", &v), allor(x, 19, v);
else scanf("%d", &v), allxor(x, 19, U), allor(x, 19, U ^ v), allxor(x, 19, U);
rt = merge(y, x, 19);
// std::cout << "ALL " << tr[rt].sz << ' ' << tr[rt].val << ' ' << tr[x].lval << "\n\n";
// exit(0);
}
return 0;
}