fft算法

fft(快速傅里叶变换)是可以在nlogn时间计算向量卷积的算法
向量卷积即计算向量f
f[i] = Σ(g[k] * h[i – k])
对于一个普通的算法 需要用n^2的时间进行计算

fft的思想是对于一个n元向量g 可以表示为n-1次多项式
这个多项式在x=1时=g[1], x=2时=g[2]….
其实两个向量的卷积 就是这两个多项式的乘积
对于一个k次多项式 需要k+1个点就可以确定
我们把g和h转为点值表示 并对他们的y值分别数乘 就可以得到f的点值表示
然后把点值表示的f转换为多项式形式 就可以得出f[]
但是注意到把多项式转为点值表示 及把点值表示转换为多项式表示都需要O(n^2)的时间
而中间的数乘步骤只需要O(n)

fft主要是解决在O(nlogn)的时间对多项式和点值的互换
先解决把多项式转为点值表示
现在要把g转为点值表示 设g为l-1次多项式 且l=2^n(如不足可以补0)
先选取采样点 我们选取e^(2πi/l * j) 其中i为sqrt(i), e为自然对数
因为:
1.e^(ix)=cosx+isinx
2.当x=π时,得到欧拉恒等式 e^(iπ)=-1 ==> e^(2iπ)=1
我们称wn=e^(2iπ/n)为n次单位复根
因为wn^n = 1
复平面上的N次单位复根N条半径可以理解为等分一个半径为1的圆
g(x) = a1 + a2x + a3x^2 + … + anx^(n-1)

g0(x) = a1 + a3x + a5x^2…
g1(x) = a2 + a4x + a6x^2…

可以发现
g0(x^2) = a1 + a3x^2 + a5x^4….
g1(x^2) = a2 + a4x^2 + a6x^4….
=>g(x) = g0(x^2) + x * g1(x^2)
设g[i][j] = g(e^(2iπ/(2^i) * j))

我们要求的是g[n][j] 这就是以wn^j为x的对应的y值 即g的点值表示
g[n][j] = g[n – 1][j] + wn^j * g[n – 1][j + (1 << (n - 1))] g[n][j + (1 << (n - 1))] = g[n - 1][j] - wn^j * g[n - 1][j + (1 << (n - 1))] (0~(1 << (n - 1))为g0 ((1 << (n - 1)) ~ (1 << n)) 为g1 证明: g[n][j] = g[n - 1][j] + wn^j * g[n - 1][j + (1 << (n - 1))]: 对于g[n - 1] wn[n - 1] = wn[n]^2 (指数为1/2) 所以g[n - 1][j] = g0(j^2)... g[n][j + (1 << (n - 1))] = g[n - 1][j] - wn^j * g[n - 1][j + (1 << (n - 1))]: 对于n - 1 只要(1 << (n - 1))就是一个循环 所以g0(j + (1 << (n - 1))) = g0(j) 有wn^j = -wn^(j + (1 << (n - 1))) 所以成立 关于非递归 就是直接从g[0]开始做 可以发现g[0][i] = a[bitre(i)] bitre(i) = 把i的二进制位翻转 然后推上去即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void fft(complex y[], const complex a[], int n)
{
  int len = 1 << n;
  for (int i = 0; i < len; ++i) y[i] = a[bitre(i, n)];
  for (int i = 1; i <= n; ++i)
  {
    int d = 1 << i;
    complex wn = complex(cos(2 * pi / d), sin(2 * pi / d));
    for (int j = 0; j < len; j += d)
    {
      complex w = complex(1, 0);
      for (int k = j; k < j + d / 2; ++k)
      {
        complex t = y[k], p = y[k + d / 2] * w;
        y[k] = t + p, y[k + d / 2] = t - p;
        w = w * wn;
      }
    }
  }
}

这样就求出了g和h的点值表示 数乘后进行ifft
可以证明 只要把上面程序改为以下即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
void fft(complex y[], const complex a[], int n, int v)
{
  int len = 1 << n;
  for (int i = 0; i < len; ++i) y[i] = a[bitre(i, n)];
  for (int i = 1; i <= n; ++i)
  {
    int d = 1 << i;
    complex wn = complex(cos(2 * pi / d * v), sin(2 * pi / d * v));
    for (int j = 0; j < len; j += d)
    {
      complex w = complex(1, 0);
      for (int k = j; k < j + d / 2; ++k)
      {
        complex t = y[k], p = y[k + d / 2] * w;
        y[k] = t + p, y[k + d / 2] = t - p;
        w = w * wn;
      }
    }
  }
  if (v == -1)
    for (int i = 0; i < len; ++i) y[i].real /= len;
}

v传入1为转点值
v传入-1为ifft

以下是bzoj2194

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <cmath>
#include <vector>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <stack>
#include <iostream>
#include <algorithm>
#include <string>
#include <functional>
#include <sstream>
using namespace std;
typedef long long ll;
typedef long double ld;
 
const int maxn = 1010000;
const double pi = acos(-1);
 
struct complex
{
  double real, imag;
  complex(double r = 0, double i = 0) : real(r), imag(i) {}
}a[maxn], b[maxn], t1[maxn], t2[maxn], ans[maxn];
 
complex operator+(complex a, complex b) { return complex(a.real + b.real, a.imag + b.imag); }
complex operator-(complex a, complex b) { return complex(a.real - b.real, a.imag - b.imag); }
complex operator*(complex a, complex b) { return complex(a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real); }
 
int n, m, len, l;
 
int bitre(int a, int n)
{
  int ans = 0;
  for (int i = 0; i < n; ++i)
    ans += (1 << i) * ((a & (1 << (n - i - 1))) ? 1 : 0);
  return ans;
}
void fft(complex y[], complex a[], int n, int rev)
{
  int len = 1 << n;
  for (int i = 0; i < len; ++i) y[i] = a[bitre(i, n)];
  for (int i = 1; i <= n; ++i)
  {
    int d = 1 << i;
    complex wn = complex(cos(2 * pi / d * rev), sin(2 * pi / d * rev));
    for (int j = 0; j < len; j += d)
    {
      complex w = complex(1, 0);
      for (int k = j; k < j + d / 2; ++k)
      {
        complex t = y[k], p = y[k + d / 2] * w;
        y[k] = t + p, y[k + d / 2] = t - p;
        w = w * wn;
      }
    }
  }
  if (rev == -1)
    for (int i = 0; i < len; ++i) y[i].real /= len;
}
 
int main()
{
  scanf("%d", &n);
  for (int i = 0; i < n; ++i)
    scanf("%lf%lf", &a[n - i - 1].real, &b[i].real);
  for (; (1 << l) <= n * 3; ++l);
  len = 1 << l;
  fft(t1, a, l, 1);
  fft(t2, b, l, 1);
  for (int i = 0; i < len; ++i) t1[i] = t1[i] * t2[i];
  fft(ans, t1, l, -1);
  for (int i = n - 1; i >= 0; --i) printf("%d\n", (int)floor(ans[i].real + 0.5));
}

发表评论