BZOJ3811 玛里苟斯

利用好了期望的性质,除了线性基几乎没有卡点,非常有意思。

题意:给定一个长度为 $n$ 的序列 $a$,任选一个子序列,求 $(\oplus_{i\in S} a_i) ^ k$ 的期望并输出准确值。$n\leq 10 ^ 5$,保证答案小于 $2 ^ {63}$,$k\leq 5$。

答案不超过 $2 ^ {63}$ 大概提示我们对于不同的 $k$ 有不同的 $a_i$ 数据范围。

首先 $k = 1$ 的情况是好做的,根据期望的线性性,每位计算,如果有一个数的某一位出现了 1,那么选与不选的概率相同,也就是 01 的概率相同。那么最后的答案就是或的和除以 2。

然后考虑 $k = 2$ 的情况。一个经典的做法是枚举两位(可以相同),计算同时为 1 的概率并乘上贡献。全是 00 贡献为 0,如果全是 00 或 11,那么两位只绑在一起的,概率为 $\dfrac 12$。否则出现一个 01 或者 10 的话,假设前面都是 11 或 00,那么出现 11 的概率和 00 的概率都是 $\dfrac 12$。来了一个 01 或者 10,11 和 00 想要不变,就不能选,那么 11 和 00 的概率都变成了 $\dfrac 14$。而 01 和 10 的概率也都变成了 $\dfrac 14$。四者概率相同后,后面的就没法再改变了。于是出现 01 或 10 的情况,就是 $\dfrac 14$ 的概率。

这时答案的二倍一定是整数,因为如果两个选择的都是最低位的话,概率不会是 $\dfrac 14$,所以最后判一下是奇数还是偶数,除以 2 输出即可。

然后考虑 $k\geq 3$ 的情况。这时由于线性基一定能表示原来所有数能表示的范围,但是线性基现在只有不超过 $21$ 位,那么我们直接暴力枚举所有可能的情况即可。其实对于剩下能被线性基元素线性表示的元素是没有意义的,因为如果选了这个元素,相当于线性基内表示他的元素出现次数 $\oplus 1$。这样只需考虑线性基内的元素,直接爆搜即可。

最后一个问题就是最后的小数如何输出。可能爆 unsigned long long,所以把小数部分和整数部分分开存,这样都不会爆。还有一个结论是答案的二倍还是整数。证明可以考虑类似 $k = 2$ 的证法,比较麻烦,就不讲了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
struct LinearBasis {
ULL a[N];
ULL& operator [](int x) { return a[x]; }
bool insert(ULL x)
{
for (int i = 63; ~i; -- i)
{
if (!(x >> i & 1)) continue;
if (!a[i]) return a[i] = x, true;
x ^= a[i];
}
return false;
}
} chk;

int main()
{
std::cin >> n >> k;
for (int i = 1; i <= n; ++ i) scanf("%llu", a + i);
if (k == 1) {
ULL res = 0;
for (int i = 1; i <= n; ++ i) res |= a[i];
printf("%llu", res >> 1);
if (res & 1) puts(".5");
return 0;
}
if (k == 2) {
ULL res = 0;
for (int b1 = 0; b1 < 32; ++ b1)
for (int b2 = 0; b2 < 32; ++ b2)
{
bool flag1 = false, flag2 = false;
for (int i = 1; i <= n && (!flag1 || !flag2); ++ i)
flag1 |= a[i] >> b1 & 1, flag2 |= a[i] >> b2 & 1;
if (!flag1 || !flag2) continue;
bool dif = false;
for (int i = 1; i <= n && !dif; ++ i)
if ((a[i] >> b1 & 1) ^ (a[i] >> b2 & 1)) dif = true;
res += 1ULL << (b1 + b2 - dif);
}
printf("%llu", res >> 1);
if (res & 1) puts(".5");
return 0;
}
int sz = 0;
for (int i = 1; i <= n; ++ i) sz += chk.insert(a[i]);
ULL res1 = 0, res2 = 0;
std::vector<ULL> all;
for (int i = 0; i < 63; ++ i)
if (chk[i]) all.push_back(chk[i]);
for (int s = 1; s < (1 << sz); ++ s)
{
ULL cur = 0;
for (int i = 0; i < sz; ++ i)
if (s >> i & 1) cur ^= all[i];
ULL a = 0, b = 1;
for (int cs = 1; cs <= k; ++ cs)
a *= cur, b *= cur, a += b >> sz, b &= (1 << sz) - 1;
res1 += a, res2 += b;
}
res1 += res2 >> sz, res2 &= (1 << sz) - 1;
std::cout << res1;
if (res2) puts(".5");
return 0;
}