daruma3940の日記

理解や文章に間違い等あればどんなことでもご指摘お願いします

FFT

FFTについて教わったことをまとめる。

N=2^nとして、

\begin{eqnarray}
\tilde{f}(y) = \sum_{x=0}^{N-1}\exp\left(\frac{2\pi xy}{N}i\right) f(x)
\end{eqnarray}
(y = 0, 1, \dots N-1)を求める為の方法。

http://nbviewer.jupyter.org/gist/genkuroki/d7328478c1188f876052fce859e91b40
こういう関数の微分を誤差が小さい方法で求めるときに使ったりする。


x,yを2進数展開したものを
\[x=\sum_{m=1}^{n} x_m 2^{n-m} =x_1 2^{n-1}+x_2 2^{n-2}+.....+x_n=x_1 x_2....x_n\]
\[y=\sum_{m=1}^{n} y_m 2^{n-m}=y_1 2^{n-1}+y_2 2^{n-2}+.....+y_n = y_1 y_2....y_n\]
のようにして書く記法をとる。
yの2進数展開をひっくり返した数列を Y_mとして書く。
y=\sum_{m=1}^{n} Y_m 2^{m-1}=Y_n Y_{n-1}....Y_1=y_1 y_2.... y_n

今後q 2^{-1}+r 2^{-2}+k 2^{-3}...0.qrk...と書くことにする。

2\piの整数倍は指数関数の中で意味はなくなるのでxy/Nの小数部分だけを考える.

\[\frac{xy}{N}=\frac{(x_1 x_2 ... x_n )(y_1 y_2 ... y_n)}{(2^n)} \]
\[=(x_1 x_2 ... x_n)(0.y_1 y_2...y_n)\]
\[ =(x_1 2^{n-1}+x_2 2^{n-2}+.....+x_n)*(y_1 2^{-1}+y_2 2^{-2}+...+y_n 2^{-n})\]
\[ =(Integer)+x_1 (0.y_n)+ x_2 (0.y_{n-1}y_n)+...+x_n(0.y_1...y_n)\]
\[ =(Integer) +x_1(0.Y_1)+x_2(0.Y_1Y_2)+...+x_n(0.Y_n...Y_1) \]

関数 f(x), \tilde{f}(y)f(x_1x_2 \dots x_n), \tilde{f}(y_1 y_2 \dots y_n)と書くことにすると、
\tilde{f}(y)

\tilde{f}(y_1 y_2 \dots y_n) = \nonumber \\
    \sum_{x_n=0}^1 e^{2\pi i(x_n\times 0.Y_{n}\dots Y_{1})}
    \dots 
    \sum_{x_{2}=0}^1 e^{2\pi i(x_{2} \times 0.Y_2 Y_1)}
     \sum_{x_{1}=0}^1 e^{2\pi i(x_1\times 0.Y_1)} f(x_1 \dots x_n)
x_mごとに分けて書くことが出来る。
x_mについて見ると、 x_mから Y_mへの2点フーリエ変換とみることが出来るので
各桁の2点フーリエ変換を繰り返すことで \tilde{f}(y)は求まる。
x_1について具体的に見ると

    f_1(Y_1=0 x_2 \dots x_n) = f(x_1=0 x_2 \dots x_n) + f(x_1=1 x_2 \dots x_n)  \\
    f_1(Y_1=1 x_2 \dots x_n) = f(x_1=0 x_2 \dots x_n) - f(x_1=1 x_2 \dots x_n)
ここで e^{2\pi i\times (0.y_1)} = e^{\pi i\times y_1} = (-1)^{y_1}を用いた。

x_mについての和は

f_m(Y_1 \dots Y_{m-1} Y_m=0 x_{m+1} \dots x_n) = \\
f_{m-1}(Y_1 \dots Y_{m-1} x_m=0 x_{m+1} \dots x_n) + e^{2\pi i\times (0.0 Y_{m-1}\dots Y_1)}
f_{m-1}(Y_1 \dots Y_{m-1} x_m=1 x_{m+1} \dots x_n)\\
f_m(Y_1,\dots Y_{m-1} Y_m=1 x_{m+1} \dots x_n) = \\
f_{m-1}(Y_1 \dots Y_{m-1} x_m=0 x_{m+1} \dots x_n) - e^{2\pi i\times (0.0 Y_{m-1} \dots Y_1)}
f_{m-1}(Y_1 \dots Y_{m-1} x_m=1 x_{m+1} \dots x_n)\\

Y_1 \dots Y_{m-1}  x_{m+1} \dots x_nのパターンはたくさんある。

これをn番目まで繰り返すとf_n(Y_1\dots Y_n)が得られ、\tilde{f}(y)

\tilde{f}(y) = f_n(y_n \dots y_1)
として求まる。

実装としては
y_1 y_2 \dots y_nをひっくり返すところをちゃんとするために
gを

    g_m [y_{n-m+1} y_{n-m+2} \dots y_n x_{m+1}\dots x_n] = 
    f_m[Y_1 \dots Y_m x_{m+1} \dots x_n]
とし,

    g_m [ y_{n-m+1}=0 y_{n-m+2} \dots y_n x_{m+1} \dots x_n ] = \\
    g_{m-1}[ y_{n-m+2} \dots y_n x_{m}=0 x_{m+1} \dots x_n ] + e^{2\pi i\times (0.0y_{n-m+2}\dots y_n)}
    g_{m-1}[ y_{n-m+2} \dots y_n x_{m}=1 x_{m+1} \dots x_n ] \\
    g_m [ y_{n-m+1}=1 y_{n-m+2} \dots y_n  x_{m+1} \dots x_n ] =   \\
    g_{m-1}[ y_{n-m+2} \dots y_n  x_{m}=0 x_{m+1} \dots x_n ] - e^{2\pi i\times (0.0y_{n-m+2}\dots y_n)}
    g_{m-1}[ y_{n-m+2} \dots y_n  x_{m}=1 x_{m+1} \dots x_n ]
と変換を繰り返す。
y_{n-m+2} \dots y_nx_{m+1} \dots x_nをそれぞれ2進数とみなして
\[ y_t= y_{n-m+2} \dots y_n \]
\[ x_t= x_{m+1} \dots x_n \]
と書くことにすると、

    g_m[y_t2^{n-m} + x_t ] =
    g_{m-1}[y_t 2^{n-m+1} + x_t] + e^{\frac{y_t}{2^{m-1}}\pi i}
    g_{m-1}[y_t 2^{n-m+1} + 2^{n-m} + x_t]\\
    g_m[2^{n-1} + y_t2^{n-m} + x_t] =
    g_{m-1}[y_t 2^{n-m+1} + x_t] - e^{\frac{y_t}{2^{m-1}}\pi i}
    g_{m-1}[y_t 2^{n-m+1} + 2^{n-m} + x_t]
この変換を 0\le xt \le 2^{n-m}-1, 0\le yt \le 2^{m-1}-1の全ての値について計算すれば良い。

atc001.contest.atcoder.jp
一応ACとれたのであってるはず。微妙だが。

        #include <vector> 
        #include <list> 
        #include <map>
        #include <set>
        #include <deque>
        #include <stack>
        #include <bitset>
        #include <algorithm>
        #include <functional>
        #include <numeric>
        #include <utility>
        #include <sstream>
        #include <iostream>
        #include <iomanip>
        #include <cstdio>
        #include <cmath>
        #include <cstdlib>
        #include <cctype>
        #include <string>
        #include <cstring>
        #include <ctime>
        #include <queue>
        #include <complex>
        using namespace std;
        
        //conversion
        //------------------------------------------
        inline int toInt(string s) { int v; istringstream sin(s); sin >> v; return v; }
        template<class T> inline string toString(T x) { ostringstream sout; sout << x; return sout.str(); }
        
        //math
        //-------------------------------------------
        template<class T> inline T sqr(T x) { return x * x; }
        
        //typedef
        //------------------------------------------
        typedef vector<int> VI;
        typedef vector<VI> VVI;
        typedef vector<string> VS;
        typedef pair<int, int> PII;
        typedef long long LL;
        
        //container util
        //------------------------------------------
        #define ALL(a)  (a).begin(),(a).end()
        #define RALL(a) (a).rbegin(), (a).rend()
        #define PB push_back
        #define MP make_pair
        #define SZ(a) int((a).size())
        #define EACH(i,c) for(typeof((c).begin()) i=(c).begin(); i!=(c).end(); ++i)
        #define EXIST(s,e) ((s).find(e)!=(s).end())
        #define SORT(c) sort((c).begin(),(c).end())
        
        //repetition
        //------------------------------------------
        #define FOR(i,a,b) for(int i=(a);i<(b);++i)
        #define REP(i,n)  FOR(i,0,n)
        
        //constant
        //--------------------------------------------
        const double EPS = 1e-10;
        const double PI = acos(-1.0);
        
        //clear memory
        #define CLR(a) memset((a), 0 ,sizeof(a))
        
        //debug
        #define dump(x)  cerr << #x << '=' << (x) << endl;
        #define debug(x) cerr << #x << '=' << (x) << '('<<'L' << __LINE__ << ')' << ' ' << __FILE__ << endl;
        
        typedef complex<double> Comp;
        const Comp img = Comp(0, 1);
        
        void FFT(vector<Comp>& f, int n) {
            int N = 1 << n;
            vector<Comp> g(N);
         
            for (int m = 1; m <= n; m++) {//どこのbitに着目しているか m番目のbitに着目している
                int max_xt = (1 << (n - m)) - 1;//これはxtの値の取りうる範囲 2^{n-m}
                int max_yt = (1 << (m - 1)) - 1;//これはytの値の取りうる範囲 2^{m-1}
         
                int start_yt = (1 << (n - m + 1));
         
                Comp a, b;
                //xt ytの取りうるすべての値を列挙
                for (int xt = 0; xt <= max_xt; xt++)for (int yt = 0; yt <= max_yt; yt++) {
                    a = f[yt*start_yt/*---yt 2^{n-m+1}---*/ + xt];//x_m==0の場合
                    b = f[yt*start_yt + (1 << (n - m)) + xt] * exp(PI*img*double(yt) / double(1 << (m - 1)));//x_m==1の場合
         
                    g[(yt << (n - m))/*---yt 2^{n-m}---*/ + xt] = a + b;//Y_m==0の場合
                    g[(1 << (n - 1))/*2^{n-1}*/ + (yt << (n - m))/*yt 2^{n-m}*/ + xt] = a - b;//Y_m==1の場合 
        
                }
                f = g;//次のループはこのgをfとして扱う
            }
        }
         
        //逆フーリエ変換をする
        //ζnをinverse(ζn)に変換(複素共益を取る) してFFTし最後にNで割るだけ
        void iFFT(vector<Comp>& f, int n) {
            int N = 1 << n;
            vector<Comp> g(N);
            for (int m = 1; m <= n; m++) {
                int max_xt = (1 << (n - m)) - 1;
                int max_yt = (1 << (m - 1)) - 1;
         
                int start_yt = (1 << (n - m + 1));
         
                Comp a, b;
                for (int xt = 0; xt <= max_xt; xt++)for (int yt = 0; yt <= max_yt; yt++) {
                    a = f[yt*start_yt + xt];
                    b = f[yt*start_yt + (1 << (n - m)) + xt] * exp(PI*(-img)*double(yt) / double(1 << (m - 1)));
         
                    g[(yt << (n - m)) + xt] = a + b;
                    g[(1 << (n - 1)) + (yt << (n - m)) + xt] = a - b;
                }
                f = g;
            }
         
            REP(i, N) {
                f[i] /= (double)N;
            }
        }
        #if 0
        int main() {
         
            const int n = 3;
            int N = 1 << n;
            vector<Comp> f(N);
            const double k = 3.0;
            REP(x, N) {
                f[x] = exp(-2 * PI*img*k*double(x) / double(N));
            }
            cout << "f=" << endl;
         
            REP(i, N) {
                cout << f[i] << " ";
            }
            cout << endl;
         
            auto a = FFT(f, n);
         
            cout << "FFT=" << endl;
            REP(i, N) {
                cout << a[i] << " ";
            }
            cout << endl;
        }
        #else
        int main() {
            int n;
            cin >> n;
            int a = 1;
            while (true) {
                if ((1 << a)>2 * n) { break; }
                a++;
            }
            
            int N = 1 << a;
            vector<Comp> A(N);
            vector<Comp> B(N);
         
            FOR(i, 0, n) {
                int x, y;
                cin >> x >> y;
                A[i] = Comp(x, 0), B[i] = Comp(y, 0);
            }
         
            FFT(A, a);
            FFT(B, a);
            REP(i, N) {
                A[i] *= B[i];
            }
            iFFT(A, a);
            cout << 0 << endl;
            FOR(i, 0, 2 * n - 1) {
                cout << (int)round(A[i].real()) << endl;
            }
        }
        #endif