0%

ACM笔记 - 利用FFT求卷积(求多项式乘法)

卷积

给定向量:\(a=(a_0,a_1,...,a_{n-1})\)\(b=(b_0,b_1,...,b_{n-1})\)

向量和:\(a+b=(a_0+b_0,a_1+b_1,...,a_{n-1}+b_{n-1})\) 数量积(内积、点积):\(a·b=a_0b_0+a_1b_1+...+a_{n-1}b_{n-1}\) 卷积\(a \otimes b=(c_0,c_1,...,c_{2n-2})\),其中\(c_k=\sum_{i+j=k}(a_ib_j)\)

例如:\(c_{n-1}=a_0b_{n-1}+a_1b_{n-2}+...+a_{n-2}b_1+a_{n-1}b_0\)

卷积的最典型的应用就是多项式乘法(多项式乘法就是求卷积)。以下就用多项式乘法来描述、举例卷积与DFT。

关于多项式

对于多项式\(A(x)\),系数为\(a_i\),设最高非零系数为\(a_k\),则其次数就是\(k\),记作\(degree(A)=k\)。任何大于\(k\)的整数都是\(A(x)\)次数界

多项式的系数表达方式\(A(x)=a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1}=\sum^{n-1}_{i=0} a_jx^j\)(次数界为\(n\))。 则多项式的系数向量即为\(a=(a_0,a_1,...,a_{n-1})\)。 多项式的点值表达方式\(\{(x_0,y_0),(x_1,y_1),...,(x_{n-1},y_{n-1})\}\),其中\(x_k\)各不相同,\(y_k=A(x_k)\)

离散傅里叶变换(DFT)

离散傅里叶变换(Discrete Fourier Transform,DFT)。在信号处理很重要的一个东西,这里物理意义以及其他应用暂不予理睬。在多项式中,DFT就是系数表式转换成点值表示的过程。

快速傅里叶变换(FFT)

快速傅里叶变换(Fast Fourier Transformation,FFT):快速计算DFT的算法,能够在\(O(n\log n)\)的时间里完成DFT。FFT只是快速的求DFT的方法罢了,不是一个新的概念。 在ACM-ICPC竞赛中, FFT算法常被用来为多项式乘法加速。FFT与其逆变换IFFT类似,稍微加几行代码。

求FFT要用到复数。一个简单的模板:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
struct Complex // 复数
{
double r, i;
Complex(double _r = 0, double _i = 0) :r(_r), i(_i) {}
Complex operator +(const Complex &b) {
return Complex(r + b.r, i + b.i);
}
Complex operator -(const Complex &b) {
return Complex(r - b.r, i - b.i);
}
Complex operator *(const Complex &b) {
return Complex(r*b.r - i*b.i, r*b.i + i*b.r);
}
};
递归实现FFT模板:来源
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
Complex* RecursiveFFT(Complex a[], int n)//n表示向量a的维数
{
if(n == 1)
return a;
Complex wn = Complex(cos(2*PI/n), sin(2*PI/n));
Complex w = Complex(1, 0);
Complex* a0 = new Complex[n >> 1];
Complex* a1 = new Complex[n >> 1];
for(int i = 0; i < n; i++)
if(i & 1) a1[(i - 1) >> 1] = a[i];
else a0[i >> 1] = a[i];
Complex *y0, *y1;
y0 = RecursiveFFT(a0, n >> 1);
y1 = RecursiveFFT(a1, n >> 1);
Complex* y = new Complex[n];
for(int k = 0; k < (n >> 1); k++)
{
y[k] = y0[k] + w*y1[k];
y[k + (n >> 1)] = y0[k] - w*y1[k];
w = w*wn;
}
delete a0;
delete a1;
delete y0;
delete y1;
return y;
}
非递归实现。模板:(来源忘了)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
void change(Complex y[], int len) // 二进制平摊反转置换 O(logn)  
{
int i, j, k;
for (i = 1, j = len / 2;i < len - 1;i++)
{
if (i < j)swap(y[i], y[j]);
k = len / 2;
while (j >= k)
{
j -= k;
k /= 2;
}
if (j < k)j += k;
}
}
void fft(Complex y[], int len, int on) //FFT:on=1; IFFT:on=-1
{
change(y, len);
for (int h = 2;h <= len;h <<= 1)
{
Complex wn(cos(-on * 2 * PI / h), sin(-on * 2 * PI / h));
for (int j = 0;j < len;j += h)
{
Complex w(1, 0);
for (int k = j;k < j + h / 2;k++)
{
Complex u = y[k];
Complex t = w*y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w = w*wn;
}
}
}
if (on == -1)
for (int i = 0;i < len;i++)
y[i].r /= len;
}

利用FFT求卷积

普通的计算多项式乘法的计算,时间复杂度\(O(n^2)\)。而FFT先将多项式点值表示(\(O(n\log n)\)),在\(O(n)\)下完成对点值的乘法,再以\(O(n\log n)\)完成IFFT,重新得到系数表示。

步骤一(补0)

在两个多项式前面补0,得到两个2n次多项式,设系数向量分别为\(v_1\)\(v_2\)

步骤二(求值)

使用FFT计算\(f_1=DFT(v_1)\)\(f_2=DFT(v_2)\)。则\(f_1\)\(f_2\)为两个多项式在\(2n\)次单位根处的取值(即点值表示)。

步骤三(乘法)

\(f_1\)\(f_2\)每一维对应相乘,得到\(f\),代表对应输入多项式乘积的点值表示。

步骤四(插值)

使用IFFT计算\(v=IDFT(f)\),其中\(v\)就是乘积的系数向量。

综上

\(a \otimes b=IDFT_{2n}(DFT_{2n}(a)·DFT_{2n}(b))\),即:\(a \otimes b=DFT^{-1}_{2n}(DFT_{2n}(a)·DFT_{2n}(b))\)

\(A(x1)\otimes B(x2)\)

1
2
3
4
5
6
fft(x1, len, 1);
fft(x2, len, 1);
for (int i = 0;i < len;i++) {
x[i] = x1[i] * x2[i];
}
fft(x, len, -1);

#例题 ##1.2016 acm香港网络赛 A题 A+B Problem 网上的代码(当时没保留出处。。。)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#include <algorithm>
#include <cstring>
#include <string.h>
#include <iostream>
#include <list>
#include <map>
#include <set>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include <cstdio>
#include <cmath>

#define LL long long
#define N 200005
#define INF 0x3ffffff

using namespace std;

const double PI = acos(-1.0);


struct Complex // 复数
{
double r, i;
Complex(double _r = 0, double _i = 0) :r(_r), i(_i) {}
Complex operator +(const Complex &b)
{
return Complex(r + b.r, i + b.i);
}
Complex operator -(const Complex &b)
{
return Complex(r - b.r, i - b.i);
}
Complex operator *(const Complex &b)
{
return Complex(r*b.r - i*b.i, r*b.i + i*b.r);
}
};

void change(Complex y[], int len) // 二进制平摊反转置换 O(logn)
{
int i, j, k;
for (i = 1, j = len / 2;i < len - 1;i++)
{
if (i < j)swap(y[i], y[j]);
k = len / 2;
while (j >= k)
{
j -= k;
k /= 2;
}
if (j < k)j += k;
}
}
void fft(Complex y[], int len, int on) //DFT和FFT
{
change(y, len);
for (int h = 2;h <= len;h <<= 1)
{
Complex wn(cos(-on * 2 * PI / h), sin(-on * 2 * PI / h));
for (int j = 0;j < len;j += h)
{
Complex w(1, 0);
for (int k = j;k < j + h / 2;k++)
{
Complex u = y[k];
Complex t = w*y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w = w*wn;
}
}
}
if (on == -1)
for (int i = 0;i < len;i++)
y[i].r /= len;
}


const int M = 50000; // a数组所有元素+M,使a[i]>=0
const int MAXN = 800040;

Complex x1[MAXN];
int a[MAXN / 4]; //原数组
long long num[MAXN]; //利用FFT得到的数组
long long tt[MAXN]; //统计数组每个元素出现个数

int main()
{
int n = 0; // n表示除了0之外数组元素个数
int tot;
scanf("%d", &tot);
memset(num, 0, sizeof(num));
memset(tt, 0, sizeof(tt));

int cnt0 = 0; //cnt0 统计0的个数
int aa;

for (int i = 0;i < tot;i++)
{
scanf("%d", &aa);
if (aa == 0) { cnt0++;continue; } //先把0全删掉,最后特殊考虑0
else a[n] = aa;
num[a[n] + M]++;
tt[a[n] + M]++;
n++;
}

sort(a, a + n);
int len1 = a[n - 1] + M + 1;
int len = 1;

while (len < 2 * len1) len <<= 1;

for (int i = 0;i < len1;i++) {
x1[i] = Complex(num[i], 0);
}
for (int i = len1;i < len;i++) {
x1[i] = Complex(0, 0);
}
fft(x1, len, 1);

for (int i = 0;i < len;i++) {
x1[i] = x1[i] * x1[i];
}
fft(x1, len, -1);

for (int i = 0;i < len;i++) {
num[i] = (long long)(x1[i].r + 0.5);
}

len = 2 * (a[n - 1] + M);

for (int i = 0;i < n;i++) //删掉ai+ai的情况
num[a[i] + a[i] + 2 * M]--;
/*
for(int i = 0;i < len;i++){
if(num[i]) cout<<i-2*M<<' '<<num[i]<<endl;
}
*/
long long ret = 0;

int l = a[n - 1] + M;

for (int i = 0;i <= l; i++) //ai,aj,ak都不为0的情况
{
if (tt[i]) ret += (long long)(num[i + M] * tt[i]);
}

ret += (long long)(num[2 * M] * cnt0); // ai+aj=0的情况

if (cnt0 != 0)
{
if (cnt0 >= 3) { //ai,aj,ak都为0的情况
long long tmp = 1;
tmp *= (long long)(cnt0);
tmp *= (long long)(cnt0 - 1);
tmp *= (long long)(cnt0 - 2);
ret += tmp;
}
for (int i = 0;i <= l; i++)
{
if (tt[i] >= 2) { // x+0=x的情况
long long tmp = (long long)cnt0;
tmp *= (long long)(tt[i]);
tmp *= (long long)(tt[i] - 1);
ret += tmp * 2;
}
}
}

printf("%lld\n", ret);

return 0;
}