LOJ2143 [SHOI2017]组合数问题

有意思的一道题目,做法比较多。

题意:求:

$n\leq 10 ^ 9$,$k\leq 50$,$p < 2 ^ {30}$。

做法 1:单位根反演

看到 $[i\bmod k = r]$,果断单位根反演。

但是我们前面说到 $k|p - 1$ 才有单位根,那怎么办呢?其实可以直接能成一个多项式的形式,类似于 $\sum_{i = 0} ^ {k - 1} a_i \omega_k ^ {i}$ 来代替单个数,这样可以实现乘法加法运算,可以得到最终解、

有一个问题就是最后的数不一定是只有 $\omega_k ^ {0}$ 位置有数,但是我们前面有看到答案一定是整数。这里可以按照 $\sum_{i = 0 } ^ {k - 1}\omega_k ^ {i} = 0$,$\omega_k ^ {\frac k2} = -1$ 等式子化简,可以证明最后一定可以得到正确答案。另外一个问题就是 $\dfrac 1k$ 并不好处理、一个神秘的办法是按照 $p\times k$ 取模计算,由于答案是整数,直接 最后 $\times \dfrac 1k$ 即可。

时间复杂度 $O(k ^ 3\log n)$ 或者是 $O(k ^ 2\log k\log n)$,有没有更优的还没有细究。

代码写了但是没写对,有时间来重写(

做法 2:矩阵乘法

我们可以看到 $\displaystyle \binom ni$ 的组合意义,就是走 $n$ 步,每一次可以走到 $(i + 1, j)$ 或者是 $(i + 1, j + 1)$,最后走到了 $(n, i)$ 这个位置的方案数。而这个式子又是可以通过递推得到的。

然后怎么把 $\bmod k = r$ 加进去呢?我们就强制走到 $(x, k)$ 就是 $(x, 0)$,这样最后回到 $r$ 这个位置的方案数就是答案。

我们就可以很容易的写出矩阵的转移方程, 矩阵乘法即可,时间复杂度 $O(k ^ 3\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

struct Matrix {
std::vector<std::vector<int>> a;
Matrix() : a(k, std::vector<int>(k)) {}
auto& operator [](int x) { return a[x]; }
Matrix operator *(Matrix b) const {
Matrix res;
for (int i = 0; i < k; ++ i)
for (int j = 0; j < k; ++ j)
for (int l = 0; l < k; ++ l)
res[i][l] = (res[i][l] + (LL) a[i][j] * b[j][l]) % Mod;
return res;
}
};

Matrix qpow(Matrix a, LL d)
{
Matrix res;
for (int i = 0; i < k; ++ i) res[i][i] = 1;
for (; d; d >>= 1, a = a * a)
if (d & 1) res = res * a;
return res;
}

int main()
{
std::cin >> n >> Mod >> k >> r, n *= k;
Matrix trs;
for (int i = 0; i < k; ++ i) trs[i][(i + 1) % k] ++, trs[i][i] ++;
trs = qpow(trs, n);
std::cout << trs[0][r] << '\n';
return 0;
}

做法 3:生成函数

看到组合数,我们很自然的(?)就想到了二项式定理,那么不考虑模数的话,生成函数就是 $(1 + x) ^ n$,然后如果需要得到次数 $\bmod k$ 的结果,直接对长度为 $k$ 的多项式循环卷积即可。

暴力卷积 $O(k ^ 2\log n)$,使用 BlueStein 算法 循环卷积可以做到 $O(k\log k)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

poly operator *(poly a, poly b) {
poly res(k);

for (int i = 0; i < k; ++ i)
for (int j = 0; j < k; ++ j)
res[(i + j) % k] = (res[(i + j) % k] + (LL) a[i] * b[j]) % Mod;

return res;
}

int main() {
std::cin >> n >> Mod >> k >> r, n *= k;
poly st(k), a(k);
a[0] ++, a[1 % k] ++, st[0] = 1;

for (; n; n >>= 1, a = a * a)
if (n & 1)
st = st * a;

std::cout << st[r] << '\n';
return 0;
}