BZOJ3473 字符串

经典 SAM 题目,似乎有很简洁的做法,但没写……

题意:给定 $n$ 个字符串,对于每个字符串,求有多少个本质不同的子串满足在至少 $k$ 个串中出现过。$n, k, \sum |S|\leq 10 ^ 5$。

看到 $n$ 个串的本质不同统计,先建广义 SAM、

我们现在要干的事情是统计每一个子串在多少个字符串中出现过。组合一道广义 SAM 和普通 SAM 略有不同,如果直接打标记按照 parent 树上传的话,会出现重复的情况。于是我们使用线段树合并,这样就可以把相同位置重复的去掉了。于是时空复杂度都是 $O(\sum|S|\log \sum|S|)$,我们可以得到每一个子串在 $n$ 个字符串的那几个所覆盖,可以接受。

这里可以不用线段树合并,直接大力标记,遇到标记过的就跳过,这样可以平衡规划证得时间复杂度为 $O(\sum |S|\sqrt{\sum |S|})$,仍然可以通过。

然后考虑如何统计每一个串的答案。直接在广义 SAM 上跑匹配,因为一个字符串一旦合法,他 parent 的父亲都合法,我们不需要向上再跳,直接假设当前节点的最大值即可。这里是线性的。

于是可以在 $O(\sum |S| \log \sum |S|)$ 可以解决,注意一下应该在哪个节点加上该位置标记的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
struct Segment_Tree {
struct Node {
int lc, rc, cnt;
} tr[N << 6];
int cnt;

void insert(int &rt, int l, int r, int pos)
{
if (!rt) rt = ++ cnt;
if (l == r) return void(tr[rt].cnt = 1);
int mid = (l + r) >> 1;
if (pos <= mid) insert(tr[rt].lc, l, mid, pos);
else insert(tr[rt].rc, mid + 1, r, pos);
tr[rt].cnt = tr[tr[rt].lc].cnt + tr[tr[rt].rc].cnt;
}

int merge(int p, int q, int l = 1, int r = n)
{
if (!p || !q) return p | q;
int cur = ++ cnt, mid = (l + r) >> 1;
if (l == r) return tr[cur].cnt = tr[p].cnt | tr[q].cnt, cur;
tr[cur].lc = merge(tr[p].lc, tr[q].lc, l, mid);
tr[cur].rc = merge(tr[p].rc, tr[q].rc, mid + 1, r);
tr[cur].cnt = tr[tr[cur].lc].cnt + tr[tr[cur].rc].cnt;
return cur;
}
} seg;

struct SAM {
struct Node {
int ch[26], len, fa;
} tr[N << 1];
int rt[N << 1], sz[N << 1];
int tot;
std::vector<int> g[N << 1];
SAM() : tot(1) {}

int extend(int ls, int c, int col)
{
if (tr[ls].ch[c]) {
int p = ls, q = tr[p].ch[c];
if (tr[q].len == tr[p].len + 1)
return seg.insert(rt[q], 1, n, col), q;
int nq = ++ tot;
seg.insert(rt[nq], 1, n, col);
tr[nq] = tr[q], tr[nq].len = tr[p].len + 1;
for (; p && tr[p].ch[c] == q; p = tr[p].fa) tr[p].ch[c] = nq;
return tr[q].fa = nq, nq;
}
int p = ls, np = ++ tot;
tr[np].len = tr[p].len + 1;
seg.insert(rt[np], 1, n, col);
for (; p && !tr[p].ch[c]; p = tr[p].fa) tr[p].ch[c] = np;
if (!p) tr[np].fa = 1;
else {
int q = tr[p].ch[c];
if (tr[q].len == tr[p].len + 1)
seg.insert(rt[q], 1, n, col), tr[np].fa = q;
else {
int nq = ++ tot;
seg.insert(rt[nq], 1, n, col);
tr[nq] = tr[q], tr[nq].len = tr[p].len + 1;
for (; p && tr[p].ch[c] == q; p = tr[p].fa) tr[p].ch[c] = nq;
tr[np].fa = tr[q].fa = nq;
}
}
return np;
}

void dfs(int x)
{
for (int v : g[x])
dfs(v), rt[x] = seg.merge(rt[x], rt[v]);
sz[x] = seg.tr[rt[x]].cnt;
}

void work()
{
for (int i = 2; i <= tot; ++ i) g[tr[i].fa].push_back(i);
dfs(1);
}
} sam;

int main()
{
std::cin.tie(0)->sync_with_stdio(false);
std::cin >> n >> k;
for (int i = 1, ls; i <= n; ++ i)
{
std::cin >> s[i];
ls = 1;
for (char c : s[i]) ls = sam.extend(ls, c - 'a', i);
}
sam.work();
// exit(0);
for (int i = 1; i <= n; ++ i)
{
int p = 1;
long long res = 0;
for (char c : s[i])
{
p = sam.tr[p].ch[c - 'a'];
while (p != 1 && sam.sz[p] < k) p = sam.tr[p].fa;
if (sam.sz[p] >= k) res += sam.tr[p].len;
}
std::cout << res << ' ';
}
return 0;
}