LOJ2988 [CTSC2016]萨菲克斯 · 阿瑞

奇怪的后缀数组计数。

题意:给定 $m$ 中互不相同的字符,每一个字符有 $c_i$ 个,问所有由这些字符组成的长度为 $n$ 的字符串中,不同的后缀数组有多少个。对 $10 ^ 9 + 7$ 取模,$n, m\leq 500$。

后缀数组是不好统计的,考虑如何将后缀数组转化为其他的一些东西。注意到我们可以把后缀数组不同位置的关系一一找出来,那么 $n$ 个字符就可以转化为一个不等式链。具体的,如果 $sa_{i + 1} + 1$ 排名比 $sa_i + 1$ 小,那么 $s_{sa_i} < s_{sa_{i + 1}}$,否则 $s_{sa_i}\leq s_{sa_{i + 1}}$。

注意到一个“满”的不等式链一定和一个后缀数组形成双射,其中“满”表示每一处的 $<$ 都确实是 $<$,不能变成 $\leq$。

现在我们就可以通过不等式链的方案数来反向计算后缀数组的个数。假设我们已经出现了 $k - 1$ 个 $<$,将原不等式链分成了 $k$ 段,分别为 $a_1, a_2\cdots, a_k$,那么不等式链的方案数就是 $\dfrac{n!}{\prod_{i = 1} ^ k a_i!}$。

这样看似很对,但是注意到一个问题:我们在计算 $k$ 段的时候,可能中间某些 $<$ 是不必要的,也就是说,在某些情况下,这个 $<$ 其实是一个 $\leq$,它并不是“满”的。这样就有可能算重了,因为不是“满”的就有可能被“满”的方案重新算一次,而这两个对应的后缀数组又是相同的。

另外一个可能算重的地方就是如果 $c_i$ 未用完,但是 $i$ 和 $i + 1$ 的交界处确实用的是 $\leq$,那么这个也是算重了的,因为我们可以把 $i + 1$ 替换成 $i$,这样就会算重。

接下来我们考虑刚才的 $\leq$ 和 $<$ 的算重问题。因为我们无法限制某一个位置一定是 $<$,所以只好考虑容斥,如果我们把 $x$ 个 $<$ 强制变为了 $\leq$,那么贡献需要乘上 $(-1) ^ x$。

接下来的事情就是比较套路的 DP 优化容斥了。设 $f(i, j, k)$ 表示考虑到第 $i$ 个字符,目前长度为 $j$,最后一段 $\leq$ 链长度为 $k$ 的方案数。注意我们这里使用类似于 EGF 的办法,先把所有的 $\dfrac{1}{a_i!}$ 乘起来,最后乘 $n!$。考虑转移:

  1. 正常的把这一段 $c_i$ 个全部接在前面 $\leq$ 链上:$f(i, j, k) \times 1\to f(i + 1, j + c_i, k + c_i)$。
  2. 可能只选择一部分,直接将该 $\leq$ 链正常结束:$f(i, j, k)\times \dfrac1{(k + l)!}\to f(i + 1, j + l, 0)$,其中 $l\in [0, c_i]$。
  3. 可能只选择一部分,本应该是 $<$,但是容斥为 $\leq$:$f(i, j, k) \times -1\to f(i + 1, j + l, k + l)$。

直接转移是 $O(n ^ 3 m)$ 的,可以考虑前缀和,把 $(j, k)$ 看作二维平面,那么第 2、3 部分计算的都是对角线的一部分,按照对角线方向前缀和即可做到 $O(n ^ 2m)$,可以通过。

不得不说,这道题后缀数组转化为不等式链的方法还是很巧妙的,另外 DP 优化容斥还是一个不错的题目。注意 $c_i = 0$ 时需要跳过,不然可能出现奇怪的无法转移的错误。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
int main()
{
init();
std::cin >> n >> m;
for (int i = 1; i <= m; ++ i)
{
scanf("%d", c + i);
if (!c[i]) -- i, -- m;
}
// f_{j, k} *-1 -> f_{j + l, k + l},
// *infact[k + l] -> f_{j + l, 0}
// -> f_{j + cnt, k + cnt}
f[0][0] = 1;
int res = 0;
for (int i = 1; i <= m; ++ i)
{
int c = ::c[i];
for (int j = 0; j <= n; ++ j)
for (int k = 0; k <= j; ++ k)
{
sum[j][k] = f[j][k];
if (j && k) adj(sum[j][k] += sum[j - 1][k - 1] - Mod);
f[j][k] = 0;
}
for (int j = 1; j <= n; ++ j)
for (int k = 1; k <= j; ++ k)
{
adj(f[j][k] -= sum[j - 1][k - 1]);
f[j][0] = (f[j][0] + (LL) sum[j - 1][k - 1] * infact[k]) % Mod;
if (j >= c && k >= c)
adj(f[j][k] += sum[j - c][k - c] - Mod);
if (j > c && k > c)
f[j][0] = (f[j][0] + (LL) sum[j - c - 1][k - c - 1] * (Mod - infact[k])) % Mod;
}
adj(res += f[n][0] - Mod);
}
res = (LL) res * fact[n] % Mod;
std::cout << res << '\n';
return 0;
}