P3321

比较难想,但其实是一个套路。

1. 题意

给定匹配串 $a$ 和原串 $b$,要求像 KMP 一样匹配,但是有通配符(指和每一个都可以匹配),给出所有的起点可以匹配。

$|a|,|b|\leq 10 ^ 5 $。

2. 思路

(以下字符串默认从 0 开始)

首先,肯定不是暴力枚举每一个通配符的匹配字符。

其中,对于一个字符串 $s$,构造

表示如果是通配符的话,就是 0,否则就是原字符本身。

我们现在假设要求 $k$ 这个位置能否匹配。

我们考虑构造:

那么,$[x^i]H(x)$ 为 0 的话,有 3 种情况:

  1. $A(i) = 0$
  2. $B(i) = 0$
  3. $A(i) = B(i + x - 1)$

可以发现,这三种情况正好对应的有通配符的情况下的匹配。

由于每一项非负,所以只要有一项不是 0,所以整个就不是 0。

那么,我们展开一下:

那么,我们就只需要求出所有的 $H(x)$,每一位都是 $H_k$,只需要统计 0 的个数就可以了。

这个很明显是一个差相等的会放到一个 $H_k$ 中,根据套路,我们把它翻转一个。

直接 NTT 就可以了。注意每一项都要 NTT,而不是一次 NTT 直接计算。

注意有人卡 998244353,直接把原根换成 5 或者模数换为 167772161 就可以了。但是确实可以被卡。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
vector<int> Match(char *s1, char *s2)
{
vector<int> mat;
int m = strlen(s1), n = strlen(s2);
reverse(s1, s1 + m);
static LL f[N], g[N], h[N], a[N], b[N];
int bit = 0;
while ((1 << bit) < (n + m + 1)) bit ++;
int tot = 1 << bit;
for (int i = 0; i < tot; ++ i) h[i] = f[i] = g[i] = 0;
for (int i = 0; i < m; ++ i)
if (s1[i] == '*') f[i] = 0;
else f[i] = s1[i] - 'a' + 1;
for (int i = 0; i < n; ++ i)
if (s2[i] == '*') g[i] = 0;
else g[i] = s2[i] - 'a' + 1;
//H(i - j) += F(i) * G(j) * (F(i) - G(j)) ^ 2
//H(i + j - m - 1) += F(i) * G1(m - j - 1) * (F(i) - G1(j)) ^ 2
/*for (int i = 0; i < tot; ++ i)
f[i] = (qpow(f[i], 3) * g[i] % Mod - qpow(f[i] * g[i] % Mod, 2) * 2 % Mod + qpow(g[i], 3) * f[i] % Mod + Mod) % Mod;*/
for (int i = 0; i < tot; ++ i) a[i] = f[i] * f[i] * f[i];
for (int i = 0; i < tot; ++ i) b[i] = g[i];
NTT(a, bit, 1), NTT(b, bit, 1);
for (int i = 0; i < tot; ++ i) h[i] = (h[i] + a[i] * b[i]) % Mod;

for (int i = 0; i < tot; ++ i) a[i] = f[i] * f[i];
for (int i = 0; i < tot; ++ i) b[i] = g[i] * g[i];
NTT(a, bit, 1), NTT(b, bit, 1);
for (int i = 0; i < tot; ++ i) h[i] = (h[i] + (Mod - 2) * a[i] % Mod * b[i]) % Mod;

for (int i = 0; i < tot; ++ i) a[i] = f[i];
for (int i = 0; i < tot; ++ i) b[i] = g[i] * g[i] * g[i];
NTT(a, bit, 1), NTT(b, bit, 1);
for (int i = 0; i < tot; ++ i) h[i] = (h[i] + a[i] * b[i]) % Mod;

NTT(h, bit, -1);
// for (int i = m - 1; i < n; ++ i) cout << h[i] << ' ';
// puts("");
for (int i = m - 1; i < n; ++ i)
if (h[i] == 0) mat.push_back(i - m + 1);
return mat;
}