利用好了期望的性质,除了线性基几乎没有卡点,非常有意思。
题意:给定一个长度为 $n$ 的序列 $a$,任选一个子序列,求 $(\oplus_{i\in S} a_i) ^ k$ 的期望并输出准确值。$n\leq 10 ^ 5$,保证答案小于 $2 ^ {63}$,$k\leq 5$。
答案不超过 $2 ^ {63}$ 大概提示我们对于不同的 $k$ 有不同的 $a_i$ 数据范围。
首先 $k = 1$ 的情况是好做的,根据期望的线性性,每位计算,如果有一个数的某一位出现了 1,那么选与不选的概率相同,也就是 01 的概率相同。那么最后的答案就是或的和除以 2。
然后考虑 $k = 2$ 的情况。一个经典的做法是枚举两位(可以相同),计算同时为 1 的概率并乘上贡献。全是 00 贡献为 0,如果全是 00 或 11,那么两位只绑在一起的,概率为 $\dfrac 12$。否则出现一个 01 或者 10 的话,假设前面都是 11 或 00,那么出现 11 的概率和 00 的概率都是 $\dfrac 12$。来了一个 01 或者 10,11 和 00 想要不变,就不能选,那么 11 和 00 的概率都变成了 $\dfrac 14$。而 01 和 10 的概率也都变成了 $\dfrac 14$。四者概率相同后,后面的就没法再改变了。于是出现 01 或 10 的情况,就是 $\dfrac 14$ 的概率。
这时答案的二倍一定是整数,因为如果两个选择的都是最低位的话,概率不会是 $\dfrac 14$,所以最后判一下是奇数还是偶数,除以 2 输出即可。
然后考虑 $k\geq 3$ 的情况。这时由于线性基一定能表示原来所有数能表示的范围,但是线性基现在只有不超过 $21$ 位,那么我们直接暴力枚举所有可能的情况即可。其实对于剩下能被线性基元素线性表示的元素是没有意义的,因为如果选了这个元素,相当于线性基内表示他的元素出现次数 $\oplus 1$。这样只需考虑线性基内的元素,直接爆搜即可。
最后一个问题就是最后的小数如何输出。可能爆 unsigned long long
,所以把小数部分和整数部分分开存,这样都不会爆。还有一个结论是答案的二倍还是整数。证明可以考虑类似 $k = 2$ 的证法,比较麻烦,就不讲了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
| struct LinearBasis { ULL a[N]; ULL& operator [](int x) { return a[x]; } bool insert(ULL x) { for (int i = 63; ~i; -- i) { if (!(x >> i & 1)) continue; if (!a[i]) return a[i] = x, true; x ^= a[i]; } return false; } } chk;
int main() { std::cin >> n >> k; for (int i = 1; i <= n; ++ i) scanf("%llu", a + i); if (k == 1) { ULL res = 0; for (int i = 1; i <= n; ++ i) res |= a[i]; printf("%llu", res >> 1); if (res & 1) puts(".5"); return 0; } if (k == 2) { ULL res = 0; for (int b1 = 0; b1 < 32; ++ b1) for (int b2 = 0; b2 < 32; ++ b2) { bool flag1 = false, flag2 = false; for (int i = 1; i <= n && (!flag1 || !flag2); ++ i) flag1 |= a[i] >> b1 & 1, flag2 |= a[i] >> b2 & 1; if (!flag1 || !flag2) continue; bool dif = false; for (int i = 1; i <= n && !dif; ++ i) if ((a[i] >> b1 & 1) ^ (a[i] >> b2 & 1)) dif = true; res += 1ULL << (b1 + b2 - dif); } printf("%llu", res >> 1); if (res & 1) puts(".5"); return 0; } int sz = 0; for (int i = 1; i <= n; ++ i) sz += chk.insert(a[i]); ULL res1 = 0, res2 = 0; std::vector<ULL> all; for (int i = 0; i < 63; ++ i) if (chk[i]) all.push_back(chk[i]); for (int s = 1; s < (1 << sz); ++ s) { ULL cur = 0; for (int i = 0; i < sz; ++ i) if (s >> i & 1) cur ^= all[i]; ULL a = 0, b = 1; for (int cs = 1; cs <= k; ++ cs) a *= cur, b *= cur, a += b >> sz, b &= (1 << sz) - 1; res1 += a, res2 += b; } res1 += res2 >> sz, res2 &= (1 << sz) - 1; std::cout << res1; if (res2) puts(".5"); return 0; }
|