FFT 详解

前方高能!

0. 前言

本节是数学之中算非常难的模板题了。

注意,如果有没学过的知识点,请记住结论即可,不必深挖使自己绕晕。

1. 解决问题

求两个高次多项式的乘积。

A(x) 多项式为 n 次,B(x) 多项式为 m 次(nm)。

朴素复杂度为 O(nm),该算法可以在 O(nlogn) 的时间求出。

2. 主要方法

1)前置知识

A. 点值表示

任意一个 n 次多项式都可以用任意 n+1 个点来表示,其中 n+1 个点都在函数上。

证明不难,就是 n+1n+1 元一次方程有唯一解。

经过“范德蒙矩阵”的行列式推导,当且仅当 xixj(ij) 时原方程有唯一解。

B. 复数的基本运算

首先,令 i=1

所以,所有的复数都可以表示为 a+bi

复数的加法:(a+bi)+(c+di)=(a+c)+(b+d)i

用向量的角度看,就是两个向量的合成。

复数的乘法:(a+bi)×(c+di)=(acbd)+(ad+bc)i

从向量的角度看,乘积的模就为原来的向量的模的乘积。

乘积的角度为 θ=θ1+θ2θ 为向量与 x 轴正半轴的夹角。

C. 复数域上的单位根

画一个单位圆。

将该圆划分为 n 份,从 x 轴正半轴逆时针取了 k 份后,终边所表示的向量记为 ωnk

n 次单位根为:ωnk(k[0,n1])

它有几个性质:

  1. ij,ωniωnj
  2. ωnk=cos2kπn+sin2kπni
  3. ωn0=ωnn=1
  4. ω2n2k=ωnk
  5. ωnk+n2=ωnk

同样,具有上述几个性质的点也是 n 次单位根。

2)核心:点值表示与系数表示的转换

在 FFT 中,我们一共取 n+m+1 个点。

怎样求出 C(x) 的表达式?

我们设定 n+m+1 个点,然后求出每一个点的函数值,就求出了点值表示。

怎样转换为系数表示呢?

首先,我们考虑取哪些点。

对于 A(x) 来说,我们取 ωnk(k[0,n1]) n 个横坐标。

我们还是首先考虑怎样从系数表示转换为点值表示。

然后,我们按次数分类,分为奇数次和偶数次。

A(x)=a0+a1x+a2x2++an1zn1

假设 n 为偶数。

A1(x)=a0+a2x+a4x2++an2xn21

A2(x)=a1+a3x+a5x2++an1xn21

所以,我们惊喜地发现:A(x)=A1(x2)+xA2(x2)

然后我们将 ωnk 代入,得到:

如果 k[0,n21],原式可以写成:

A(ωnk)=A1(ωn2k)+ωnkA2(ωn2k)=A1(ωn2k)+ωnkA2(ωn2k)

如果 k[n2,n1],原式可以写成:

A(ωnk)=A1(ωn2k)+ωnkA2(ωn2k)=A1(ωn2kn)ωnkn2A2(ωn2kn)

这样,我们就可以在 O(nlogn) 的时间内从系数表示转换到点值表示。

下面,我们进行逆变化:从点值表示转换到系数表示。

假设最终的答案为 A(x)=c0+c1x++cn1xn1

结论:

ck=i=0n1yi(ωnk)in

其中,yi 表示最终得到的纵坐标。

首先,我们假设该结论成立。

再定义 B(x)=y0+y1x+y2x2++yn1xn1

不难发现 ci=B(ωni)

所以,我们相当于把 B 从系数表示法转换为点值表示法,就在做一次前面的即可。

可以发现(?),有一个负号不影响答案。

好,下面我们来证明该结论。

证明:(从结论反推)

nck=i=0n1yi(ωnk)i=i=0n1(j=0n1cj(ωni)j)(ωnk)i=i=0n1j=0n1cj(ωni)j(ωnk)i=i=0n1j=0n1cj(ωnjk)i=j=0n1cj(i=0n1(ωnjk)i)

在最里面的括号 (ωnjk)ij,k 都是常量,i 是变量,所以我们可以再构造 D(x)=i=0n1xi,原式就可以化成:

=j=0n1cjD(ωnjk)

接着,我们讨论 D(ωnx) 的取值:

x0:,首先,D(ωnx)=ωn0+ωnx+ωn2x++ωn(n1)x

其次,ωnxD(ωnx)=ωnx+ωn2x++ωn(n1)x+(ωnnx=0)

发现,这两项相等。

于是,(ωnx1)D(ωnx)=0,又 x0,所以 ωnx1,于是 D(ωnx)=0

x=0D(ωnx)=D(1)=n

所以,当 jk 时,原式为 0

带入原式,便可以得到:

=nck

证毕。

现在,我们讨论的是怎样求出 n 次单位根对应的点值。

首先,有前面推导的式子:

A(ωnk)=A1(ωn2k)+ωnkA2(ωn2k)A(ωnk+n2)=A1(ωn2k)ωnkA2(ωn2k)

在实践中,递归的常数较大,我们使用迭代。

首先,我们发现,如果 in2,就可以得到 ai=bi+ωnici,如果 i>n2,就可以得到 ai=biωnin2ci,其中 bi,ci 分别表示左边和右边求出来的系数,分别对应偶数求出来的点值和奇数次方求出来的点值。

我们每一次求点值时,会递归并将偶数次项的放前面,将奇数次项的放后面,然后计算出前后的 n 次方根的数值,然后将答案合并即可。

来看一张图。

(红色 a 表示系数的长度,蓝色 b 表示已经求出的点值的合并,黑色表示点值的转移)

首先,我们要预处理第一层,即 b 的最上面一层。

我们找一下规律:第零个 0,第一个 4,……

可以发现,答案求出的就是二进制的翻转,这个例子中是指三位的异或。

我们考虑刚才的结论怎么证明。

其实很简单:如果最后一位为 1,则放在右边,也就是当前的最高位为 1,也就是翻转了。

假设 bit 位的翻转,我们怎样求出翻转结果呢?

记要翻转的数为 i,结果为 rev(i)

首先,我们将最后一位不看,即为 i/2,然后将剩下的翻转,为 rev(i/2),然后在最前面补上前面的最后一位。

所以,我们可以得到一个递推公式:rev[i] = (rev[i >> 1]) | (i & 1) << (bit - 1)

至此,所有的推导全部结束。

大功告成!!!

最后的最后,我们总结一下整体的思路。

首先,将两个式子分别从系数表示转换为点值表示,其中要使用到递归(或迭代)求点值,最后用两边合并为当前,时间复杂度为 O(nlogn)

然后,用点值表示将两个式子乘起来,直接得到 O(n)

然后,我们将得到的式子用一个证明,即为点值做系数,ωnk 做自变量,就可以求出 nck 的值,这再做一次前面的正向即可。

最后,除以 n 输出即可。

至此,上代码。可能还是有一些地方没有讲清楚,请读者看代码吧(逃

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;

const int N = 3e6 + 10;
const double PI = acos(-1);

struct Complex{
double x, y;
const Complex operator +(const Complex &t)const{
return (Complex){x + t.x, y + t.y};
}
const Complex operator -(const Complex &t)const{
return (Complex){x - t.x, y - t.y};
}
const Complex operator *(const Complex &t)const{
return (Complex){x * t.x - y * t.y, x * t.y + y * t.x};
}
}a[N], b[N];
int tot, bit, rev[N];

void FFT(Complex a[], int inv)
{
for (int i = 0; i < tot; ++ i)
if (rev[i] < i) swap(a[rev[i]], a[i]);

for (int mid = 1; mid < tot; mid <<= 1)
{
Complex w1 = (Complex){cos(PI / mid), sin(inv * PI / mid)};
for (int i = 0; i < tot; i += mid * 2)
{
Complex now = (Complex){1, 0};
for (int j = 0; j < mid; j ++, now = now * w1)
{
Complex x = a[i + j], y = now * a[i + j + mid];
a[i + j] = x + y, a[i + j + mid] = x - y;
}
}
}
}

int main()
{
int n, m;
scanf("%d %d", &n, &m);
for (int i = 0; i <= n; ++ i) scanf("%lf", &a[i].x);
for (int i = 0; i <= m; ++ i) scanf("%lf", &b[i].x);

while ((1 << bit) < n + m + 1) bit ++;
tot = 1 << bit;

for (int i = 0; i < tot; ++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1);

FFT(a, 1);FFT(b, 1);
for (int i = 0; i < tot; ++ i) a[i] = a[i] * b[i];
FFT(a, -1);

for (int i = 0; i <= n + m; ++ i) printf("%d ", int((a[i].x / tot) + 0.5));

return 0;
}

Gitalking ...