后缀数组

和后缀自动机几乎是字符串最难的两个专题了。

注意很多题不能相互替代。

1. 主要思想

首先,我们一般要依靠两个算法:

  1. 倍增 $O(n\log n)$
  2. DC3 $O(n)$,常数较大。

一般使用 $O(n \log n)$ 的倍增算法。

该算法是针对字符串的,可以在 $O(n\log n)$ 的时间内将所有的后缀全部排序。

假设下标从 1 开始,其中从 $i$ 开始的后缀被称为第 $i$ 个后缀。显然没有两个 $s[i]$ 是相同的。

排序后,我们可以得到 $sa[1…n]$ 的数组,代表排名为 $i$ 的是第几个后缀。按照字典序排序。定义 $s[i]$ 为从 $i$ 开始的后缀。注意有时直接使用 $i$ 来代替 $s[i]$,请注意。

还可以得到 $rk[1…n]$ 代表第 $i$ 个后缀的排名是多少。显然 $sa[rk[i]] = i$。

还有一个比较重要而常用的数组 $height[n]$,表示 $sa[i]$ 和 $sa[i-1]$ 的最长公共前缀

请注意排名为 $i$ 的 $sa[1…n]$ 和 $s[i]$ 的排名 $rk[1…n]$ 的区别。

2. 实现方法

首先,我们按第一个字符进行排序,相对位置不变(即如果第一个有相同的,在前面的还在前面)。

使用倍增。

假设当前已经处理了前 $k$ 个字符,我们将前 $k$ 个字符当做第一关键字,将接着的 $k$ 个字符当做第二关键字。这样就可以将每一个后缀按照前 $2k$ 个字符进行字典序排序。

没有的(即长度不满 $2k$)一定比长度 $\geq 2k$ 更小。

前 $k$ 个字符可以离散化,接着的 $k$ 个字符也可以离散化。

然后每一次排序就可以使用基数排序(不记得的先去复习一下)。

如果有两个关键字,可以先按第二关键字排序,再按第一关键字排序。

这里简单的讲解一下这里使用到的基数排序:

  1. 首先将第一关键字 $x$ 扔入一个桶 $cnt$ 中。
  2. 将桶做一遍前缀和。这时可以发现,$(x,y)$ 这一个元素的最大排名就是 $cnt[x]$。
  3. 然后,按照第二个关键字的逆序,将 $(x,y)$ 的排名赋值 $cnt[x]$。此时还没有枚举到的 $(x,y’)$(即 $x$ 相同,$y’$ 更小)的最大排名变小一位,于是 $cnt[x]\leftarrow cnt[x] - 1$。

如果还不理解,我们拿一个例子来看。

假设我们要排序 $(1,3),(2,2),(1,5),(3,7)$。

首先执行第一步,得到:$cnt[] = \{2, 1, 1\}$。

然后,前缀和得到: $cnt[] = \{2, 3, 4\}$。

然后按照 $y$ 逆序排序,首先枚举 $(3,7)$,得到 $rk[(3,7)] = 4$,$cnt[3]\leftarrow cnt[3] - 1$。

同理,然后会枚举 $(1,5)$,于是就是 $rk[(1,5)] = 2$,$cnt[1] = 1$。

枚举 $(1,3)$,得到 $rk[(1,3)] = cnt[1] = 1$,$cnt[1] = 0$。

最后枚举 $(2,2)$,略去。

由于基数排序是稳定排序,我们就没有影响这里的前后顺序。

我们给出代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
void get_sa()
{
for (int i=1;i<=n;++i) c[x[i]=s[i]]++;
for (int i=1;i<=m;++i) c[i]+=c[i-1];
for (int i=n;i;i--) sa[c[x[i]]--]=i;//按照第一个字符排序,其实是 (s[i], i) 排序,使用基数排序
for (int k=1;k<=n;k<<=1)
{
int num=0;
for (int i=n-k+1;i<=n;++i) y[++num]=i;
for (int i=1;i<=n;++i)
if (sa[i]>k) y[++num]=sa[i]-k;//y[] 的意思是第二个关键字排名为 i 的后缀是哪一个。但 x[] 却是 i 为后缀在前 k 个排序时的排名
for (int i=1;i<=m;++i) c[i]=0;//等价于 cnt[],注意清空
for (int i=1;i<=n;++i) c[x[i]]++;
for (int i=2;i<=m;++i) c[i]+=c[i-1];
for (int i=n;i;i--) sa[c[x[y[i]]]--]=y[i],y[i]=0;//注意存的不是 rk[],而是 sa[],所以 y[i] 和 cnt[x[y[i]] --] 是不同的。等价与 rk[y[i]] = cnt[x[y[i]] --]
swap(x,y);//由于第二关键字已经完全使用了,我们不再需要,将第一个关键字排序的结果转到临时变量(是按照前 k 个的排序结果)。注意存的是每一后缀的排名。
x[sa[1]]=1,num=1;
for (int i=2;i<=n;++i)//明显比较前 2k 字符的时候,sa[i] 一定不会比 sa[i - 1] 的排名小。所以 num 不减。
if (y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])
x[sa[i]]=num;//如果 sa[i] 和 sa[i - 1] 的排名相等,且 sa[i] + k 和 sa[i - 1] + k 的排名相等,那么说明 sa[i] 和 sa[i - 1] 比较前 2k 字符时都是相等的。
else x[sa[i]]=++num;
if (num==n) return;
m=num;
}
}

再给一个简洁的代码,供参考。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
void get_sa(char *str, int *sa, int n)
{
static int c[N], x[N], y[N];
int m = 126, num;
for (int i = 1; i <= m; ++ i) c[i] = 0;
for (int i = 1; i <= n; ++ i) c[x[i] = str[i]] ++;
for (int i = 1; i <= m; ++ i) c[i] += c[i - 1];
for (int i = n; i; -- i) sa[c[x[i]] --] = i;
for (int k = 1; k <= n; k <<= 1)
{
num = 0;
for (int i = n - k + 1; i <= n; ++ i) y[++ num] = i;
for (int i = 1; i <= n; ++ i)
if (sa[i] > k) y[++ num] = sa[i] - k;
for (int i = 1; i <= m; ++ i) c[i] = 0;
for (int i = 1; i <= n; ++ i) c[x[i]] ++;
for (int i = 1; i <= m; ++ i) c[i] += c[i - 1];
for (int i = n; i; -- i) sa[c[x[y[i]]] --] = y[i];
for (int i = 1; i <= n; ++ i) y[i] = 0;
swap(x, y);
x[sa[1]] = 1, num = 1;
for (int i = 2; i <= n; ++ i)
if (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) x[sa[i]] = num;
else x[sa[i]] = ++ num;
if (num == n) return;
m = num;
}
}

怎样求和利用 $height[]$ 呢?

首先再次明确定义,$height[1…n]$ 表示的是 $sa[i]$ 和 $sa[i - 1]$ 的最长公共前缀。

我们定义 $lcp(i,j)$ 为 $s[sa[i]]$ 与 $s[sa[j]]$ 的最长公共前缀。

证明1:$lcp(i,j)=\min(lcp(i,k),lcp(k,j))[i\leq k \leq j]$。

首先证明 $\geq$。

很明显,如果 $lcp(i,k)>lcp(i,j),lcp(k,j)>lcp(i,j)$,那么 $i$ 和 $k$、$k$ 和 $j$ 都有更长的后缀,也就是 $i$ 和 $j$ 有更长的后缀。

再证明 $\leq$。

假设 $i$、$j$、$k$ 前 $lcp(i,j)$ 的串为 $A,B,C$。

很明显,有 $A\leq B\leq C$,又有 $A=C$,则 $A=B=C$。

证毕

为了方便,我们假设 $h(i)=height[rk[i]]$。就是指 $i$ 的后缀与排序后在 $i$ 前面一个的最长公共前缀。

证明2: $h(i)\geq h(i-1)-1$。

这里假设 $h(i - 1)\geq 1$。如果 $h(i - 1) = 0$,原式显然。

假设 $i-1$ 的前面一个是第 $k$ 个后缀,那么 $i-1$ 与 $k$ 的最长公共前缀就是 $h(i-1)$。

都去掉第一个字符,就可以得到 $k + 1$ 与 $i$ 的公共前缀是 $h(i-1)-1$。(因为去掉第一个字符,相当于 $i$ 为后缀的字符串变为了 $i + 1$ 为后缀的字符串。

因为 $k$ 在 $i-1$ 前面,那么 $k+1$ 在 $i$ 的前面。因为第一个字符是一样的,所以比较 $k$ 和 $i - 1$ 的比较其实是从 $k + 1$ 和 $i$ 的比较得来的。既然 $k$ 比 $i - 1$ 小,那么肯定 $k + 1$ 比 $i$ 小。

由前面的证明可以得到 $h(i)$ 大于等于 $i$ 与 $k+1$ 的最长公共前缀长度。

即 $h(i)\geq h(i-1)-1$。

证毕

有了这个结论,我们在每次求 $height[i]$ 时,就可以使用该结论。

1
2
3
4
5
6
7
8
9
10
11
12
void get_height()
{
for (int i=1;i<=n;++i) rk[sa[i]]=i;
for (int i=1;i<=n;++i)
{
if (rk[i]==1) continue;
int j=sa[rk[i]-1],k=max(0,height[rk[i-1]]-1);
//j 的位置一定注意,我们不管是什么,都一定求的是 sa[rk[i]] 和 sa[rk[i - 1]] 的最长公共前缀,只是因为证明,我们才使用了 rk[i - 1]
while (i+k<=n&&j+k<=n&&s[i+k]==s[j+k]) k++;
height[rk[i]]=k;
}
}

3. 例题

T1:[NOI2015]品酒大会

题目传送门 Luogu

题目传送门 AcWing

能很好的体现后缀数组的作用和使用方法。

在使用后缀数组后,我们将会得到所有后缀的排名 $sa[]$,以及所有后缀与前一名的最大公共前缀 $height[]$。

首先考虑后缀的最大公共前缀与 $r$ 的关系。

由 “证明1” 可得,如果 $height[i]<r$,则 $i$ 上面的和下面的不可能 $lcp$ 大于等于 $r$。

同时我们也可以得到结论:如果将所有分成几段,则段内一定都是 $r$ 相似。

怎样维护最大值?

可以维护最大值和次大值,乘起来即可。

由于有负数,也要维护最小值和次小值。

由于有些绕,所以得认真打。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;

typedef long long LL;
typedef pair<LL, LL> PLL;
const int N = 3e5 + 10;
const LL INF = 2e18;
char str[N];
int sa[N], n, height[N], fa[N], a[N];
int mx1[N], mx2[N], mn1[N], mn2[N], sz[N];
LL cnt, res = -INF;
PLL ans[N];
vector<int> com[N];

void get_sa(char *str, int *sa, int n)
{
static int c[N], x[N], y[N];
int m = 126, num;
for (int i = 1; i <= m; ++ i) c[i] = 0;
for (int i = 1; i <= n; ++ i) c[x[i] = str[i]] ++;
for (int i = 1; i <= m; ++ i) c[i] += c[i - 1];
for (int i = n; i; -- i) sa[c[x[i]] --] = i;
for (int k = 1; k <= n; k <<= 1)
{
num = 0;
for (int i = n - k + 1; i <= n; ++ i) y[++ num] = i;
for (int i = 1; i <= n; ++ i)
if (sa[i] > k) y[++ num] = sa[i] - k;
for (int i = 1; i <= m; ++ i) c[i] = 0;
for (int i = 1; i <= n; ++ i) c[x[i]] ++;
for (int i = 1; i <= m; ++ i) c[i] += c[i - 1];
for (int i = n; i; -- i) sa[c[x[y[i]]] --] = y[i];
swap(x, y);
x[sa[1]] = 1, num = 1;
for (int i = 2; i <= n; ++ i)
if (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) x[sa[i]] = num;
else x[sa[i]] = ++ num;
if (num == n) return;
m = num;
}
}

void get_height(char *str, int *sa, int *height, int n)
{
static int rk[N];
for (int i = 1; i <= n; ++ i) rk[sa[i]] = i;
for (int i = 1; i <= n; ++ i)
{
if (rk[i] == 1) continue;
int j = sa[rk[i] - 1], k = max(0, height[rk[i - 1]] - 1);
while (i + k <= n && j + k <= n && str[i + k] == str[j + k]) k ++;
height[rk[i]] = k;
}
}

int find(int x)
{
if (fa[x] == x) return x;
return fa[x] = find(fa[x]);
}

PLL solve(int x)
{
for (auto i : com[x])
{
int a = i, b = i - 1;
a = find(a), b = find(b);
cnt -= 1LL * sz[a] * (sz[a] - 1) / 2 + 1LL * sz[b] * (sz[b] - 1) / 2;
if (mx1[a] > mx1[b]) mx2[a] = max(mx2[a], mx1[b]);
else if (mx1[a] < mx1[b]) mx2[a] = max(mx1[a], mx2[b]), mx1[a] = mx1[b];
else mx2[a] = mx1[a];
if (mn1[a] < mn1[b]) mn2[a] = min(mn2[a], mn1[b]);
else if (mn1[a] > mn1[b]) mn2[a] = min(mn1[a], mn2[b]), mn1[a] = mn1[b];
else mn2[a] = mn1[a];
sz[a] += sz[b], sz[b] = 0;
fa[b] = a;
cnt += 1LL * sz[a] * (sz[a] - 1) / 2;
res = max(res, max(1LL * mx1[a] * mx2[a], 1LL * mn1[a] * mn2[a]));
}
if (cnt == 0) return {0, 0};
else return {cnt, res};
}

int main()
{
scanf("%d%s", &n, str + 1);
for (int i = 1; i <= n; ++ i) scanf("%d", &a[i]);
for (int i = 1; i <= n; ++ i) fa[i] = i, sz[i] = 1;
get_sa(str, sa, n), get_height(str, sa, height, n);
for (int i = 1; i <= n; ++ i) mx1[i] = mn1[i] = a[sa[i]], mx2[i] = -1e9, mn2[i] = 1e9;
for (int i = 2; i <= n; ++ i) com[height[i]].push_back(i);
for (int i = n - 1; ~i; -- i) ans[i] = solve(i);
for (int i = 0; i < n; ++ i) printf("%lld %lld\n", ans[i].first, ans[i].second);
return 0;
}

T2:[SDOI2016]生成魔咒

题目传送门 Luogu

题目传送门 AcWing

首先容易得到,所有后缀的所有前缀集合就是所有子串的集合。

证明3:在最长公共前缀内的,前面都出现过;在外面的,前面都没有出现过。

前一条易证。

如果在外面的某一个前缀在前面出现过,但是没有 $i - 1$ 出现过,$lcp(i,i-1)< lcp(i,k)$,则与已知矛盾。

至于从后面加,我们可以将序列翻转,最后再翻转输出即可。

然后,我们删除最后一个字符的时候,直接将它所在的后缀删除就是了,然后将它的下面的 $height$ 维护为 $\min(height[now], height[nxt])$,可以 $O(1)$ 维护。总个数也可以 $O(1)$ 维护。

此外,此题还有 $O(n)$ 的在线做法,可能会在 SAM 中讲解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <unordered_map>
using namespace std;

typedef long long LL;
const int N = 1e5 + 10;
int str[N];
int sa[N], n, m, height[N], rk[N];
unordered_map<int, int> dis;
LL ans[N], res;
int d[N], u[N];

void get_sa(int *str, int *sa, int n)
{
static int c[N], x[N], y[N];
int num;
for (int i = 1; i <= m; ++ i) c[i] = 0;
for (int i = 1; i <= n; ++ i) c[x[i] = str[i]] ++;
for (int i = 1; i <= m; ++ i) c[i] += c[i - 1];
for (int i = n; i; -- i) sa[c[x[i]] --] = i;
for (int k = 1; k <= n; k <<= 1)
{
num = 0;
for (int i = n - k + 1; i <= n; ++ i) y[++ num] = i;
for (int i = 1; i <= n; ++ i)
if (sa[i] > k) y[++ num] = sa[i] - k;
for (int i = 1; i <= m; ++ i) c[i] = 0;
for (int i = 1; i <= n; ++ i) c[x[i]] ++;
for (int i = 1; i <= m; ++ i) c[i] += c[i - 1];
for (int i = n; i; -- i) sa[c[x[y[i]]] --] = y[i];
swap(x, y);
x[sa[1]] = 1, num = 1;
for (int i = 2; i <= n; ++ i)
if (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) x[sa[i]] = num;
else x[sa[i]] = ++ num;
if (num == n) return;
m = num;
}
}

void get_height(int *str, int *sa, int *height, int n)
{
for (int i = 1; i <= n; ++ i) rk[sa[i]] = i;
for (int i = 1; i <= n; ++ i)
{
if (rk[i] == 1) continue;
int j = sa[rk[i] - 1], k = max(0, height[rk[i - 1]] - 1);
while (i + k <= n && j + k <= n && str[i + k] == str[j + k]) k ++;
height[rk[i]] = k;
}
}

int discrete(int x)
{
if (dis.count(x) == 0) dis[x] = ++ m;
return dis[x];
}

int main()
{
cin >> n;
for (int i = 1; i <= n; ++ i) scanf("%d", &str[i]);
for (int i = 1; i <= n; ++ i) str[i] = discrete(str[i]);
reverse(str + 1, str + n + 1);
get_sa(str, sa, n), get_height(str, sa, height, n);
for (int i = 1; i <= n; ++ i) res += n - sa[i] + 1 - height[i];
for (int i = 1; i <= n; ++ i) u[i] = i - 1, d[i] = i + 1;
u[n + 1] = n, d[0] = 1;
for (int i = 1; i <= n; ++ i)
{
ans[i] = res;
int x = rk[i], y = d[x];
res -= n - sa[x] + 1 - height[x] + n - sa[y] + 1 - height[y];
height[y] = min(height[y], height[x]);
res += n - sa[y] + 1 - height[y];
u[d[x]] = u[x], d[u[x]] = d[x];
}
for (int i = n; i; -- i) printf("%lld\n", ans[i]);
return 0;
}