Senの競技プログラミング備忘録

こけた問題を自分用の解説で載せる。けんちょんさんのブログを目指したい。質的にも量的にも。こけた問題だけに限定するけど

NTT(数論変換)のやさしい解説

この記事は、高速フーリエ変換の1つのバリエーションである、Numeric Theory Translation(数論変換)の詳しい解説記事です。 あまりなかったので書きました。 数論変換にありがちな

  • なんで$ 10 ^ 9 + 7 $じゃダメなのか
  • 原始根って何?

とかについてもこれを見ればわかります。

前提知識としては、Fast Fourier Translation(高速フーリエ変換)が必要です。 kaage大先生のQiita記事とかは中高生にもわかりやすく書かれています。

下に一応自分の勉強ノートを載せます。(上の記事の行間を埋めた感じです)

では、数論変換について説明したいと思います。

高速フーリエ変換の弱点

高速フーリエ変換は正しいですが、弱点として精度が足りないというのがあります。 なぜならば、複素数の1の$ 2 ^ m $乗根は計算上ではれっきとした64bit倍精度浮動小数点のdoubleのペアなので、$ m $が増えるとdoubleではどうしても精度が足らなくなります。

一例として、$ (10x ^ 9 + 9x ^ 8 + 8x ^ 7 + 7x ^ 6 + 6x ^ 5 + 5x ^ 4 + 4x ^ 3 + 3x ^ 2 + 2x + 1) ^ 2 $の畳み込みを見てみましょう。自作のFFTの計算結果は以下です。

(1, 7.43086e-14)
(4, 1.03461e-13)
(10, 1.03814e-13)
(20, 1.16975e-13)
(35, 1.51601e-13)
(56, 1.64745e-13)
(84, 1.56105e-13)
(120, 1.48062e-13)
(165, 1.53148e-13)
(220, 1.27886e-13)
(264, 8.78265e-14)
(296, 2.90458e-14)
(315, -6.56697e-14)
(320, -1.01708e-13)
(310, -1.21006e-13)
(284, -1.0951e-13)
(241, -1.74825e-13)
(180, -1.87862e-13)
(100, -1.75074e-13)

実数部分は問題ありませんが、複素部分は$ 10 ^ {-13} $オーダーのエラーが出ています。倍精度浮動小数点の計算機イプシロン(double型で表す1より大きい最小の数)が$ 2.220 \times 10 ^ {-16} $ であることを考えると、なかなかに無視できない誤差ではないかとわかります。

実際、畳み込みの結果が$ 10 ^ {8} $オーダーになると、倍精度浮動小数点では精度が足りなくなります。そして、畳み込みなどを使う数え上げではこの限界をいとも簡単に突破する問題が多いです。

FFTアルゴリズム自体は正しいのですが、計算機の精度保持の限界の問題で大きな値は正しく計算できないのです。

FFTのn乗根

FFTのn乗根は、複素数$ \zeta _ n = (\cos \frac{2 \pi}{n} + i \sin \frac{2 \pi}{n}) $が採用されていました。この値は、以下の性質を持っていました。

  • $ \zeta _ n ^ 0, \zeta _ n ^ 1, \cdots , \zeta _ n ^ {n - 1} $の値はみんな違う。
  • $ \zeta _ n ^ n = 1 $(n乗すると単位ユニット(=単位元)になる)
  • $ \sum _ {i = 0} ^ {n - 1} (\zeta _ n ^ k) ^ i $は$ 0(k \neq 0) $か$ n(k = 0) $

逆に言うとこれを満たすものであれば、離散フーリエ変換に代入する値としてOKです。(FFTをするために$ n = 2 ^ m $の形としないといけないけど)

複素数の世界では、候補はこれしかありませんでした。しかし、世界を変えてみるとどうなるでしょう?

mod pの世界

整数を$ p $で割ったあまりの世界(剰余環)を考えます。この世界での四則演算について考えると、割り算以外なんでもできそうな気がします。 そして、$ p $が素数の場合、割り算も定義できます。(けんちょん大先生の記事を参考)

そして、この世界で、$ n $乗したら1になるものも$ n $乗根といいます。

mod pの世界の原始根

ここで、mod $ p $($ p $は素数)の世界の原始根という考え方を説明します。

mod $ p $の世界で、ある要素$ g $があって、$ g ^ 0, g ^ 1, g ^ 2, \cdots , g ^ {p - 2} $がみんな違う値を取るような$ g $を原始根といいます。

詳しいのはけんちょんさんと同じ研究室の人が書いた高校数学の美しい物語の記事をご覧ください。

mod pの世界での離散フーリエ変換

mod $ p $(素数)の世界でも、$ n=2 ^ m $乗根にあたるものを考えてみます。これは$ g ^ n = 1 $を満たす必要があります。ところで、mod $ p $の世界で有名なフェルマーの小定理があり、 $$ g ^ {p - 1} \equiv 1 (mod p) $$

が成り立ちます。このことから、$ p = 2 ^ m \times a + 1 $($ a $は2と互いに素)と書ける場合、 $$ g ^ {2 ^ m \times a} \equiv 1 (mod p), (g ^ {a}) ^ {2 ^ m} \equiv 1 (mod p) $$

となります。つまり、 $ 2 ^ m $乗根は$ g ^ {a} $ となります。(一方で、$ 2 ^ {n + 1} $乗根は存在しません)

これで、複素数の上で定義していた$ 2 ^ m $乗根を定めることができました。 これを2乗することで、$ 2 ^ {m - 1} $乗根も複素と同様に定めることができますので、mod $ p $の上の1乗根、2乗根、4乗根、...、$ 2 ^ {m} $乗根と全て求められます。

mod $ p $の上での割り算も定義されているので、FFTをmod $ p $の世界で考えることができました。

$ 10 ^ 9+7 $と$ 998244353 $の違い

NTT有名素数modとして、$ 998244353 $が挙げられます。この値は、 $$ 998244353 = 2 ^ {23} \times 119 + 1 $$ となります。

これは、mod $ 998244353 $では$ 2 ^ {23} $乗根まで存在する($ 2 ^ {24} $乗根はない)ということです。 $ 2 ^ {23} $乗根まで存在するので、mod $ 998244353 $の世界で$ 2 ^ {23} = 8300608 $から、$ 8388607 $次までの畳み込みを、FTTと同じように分割統治で計算することができます。

このことから、競プロでは一般的に実用的な畳み込みの余りは$ 998244353 $となっています。

一方、$ 10 ^ 9 + 7 $の場合 $$ 10 ^ 9 + 7 = 2 ^ 1 \times 50000003 + 1 $$ であります。

これは、mod $ 10 ^ 9 + 7 $ の世界では、2乗根までしかなく、4乗根すらないということです。 つまり1次式までしか畳み込みできません。 さすがにこれは使えないです。

$ 10 ^ 9+7 $で無理やり畳み込みをするには?

mod $ 10 ^ 9+7 $は1の4乗根が存在しないことから、畳み込みはかなり困難になります。一応頑張ればできるらしいのですが、計算量的にも実装量的にも手法的にも全く実用的ではありません。しかし、任意modでの畳み込みをしたいというのならば、複数回のNTTとGarnerのアルゴリズムというもので復元できます。

やり方の詳細はここにあります。

実装例

確認問題

//mintはModintである。
//畳み込みをする前にsetup()を実行する。
typedef std::vector<mint> vectorM;//NTT用のmintのベクター型
const int DIVIDE_LIMIT = 23;//99...の有名素数は23回分割統治できる。
mint ROOT[DIVIDE_LIMIT + 1];//[i]は2^i乗根 99...の有名素数の原始根は3で、そこから2^22乗根, 2^21...などをsetup()で計算する。
mint inv_ROOT[DIVIDE_LIMIT + 1];//[i]は2^i乗根の逆数 setup()で計算する。
mint PRIMITIVE_ROOT = 3;

void setup() {
    ROOT[DIVIDE_LIMIT] = modpow(PRIMITIVE_ROOT, (MOD - 1) / modpow(2, 23).val);//99..なら119乗
    inv_ROOT[DIVIDE_LIMIT] = 1 / ROOT[DIVIDE_LIMIT];
    for (int i = DIVIDE_LIMIT - 1; i >= 0; i--) {
        ROOT[i] = ROOT[i + 1] * ROOT[i + 1];
        inv_ROOT[i] = inv_ROOT[i + 1] * inv_ROOT[i + 1];
    }
}

vectorM ntt(const vectorM& f, const int inverse, const int log2_f, const int divide_cnt = DIVIDE_LIMIT) {
    vectorM ret;
    if (f.size() == 1 || divide_cnt == 0) {
        ret.resize(f.size());
        mint zeta = 1;
        for (int i = 0; i < ret.size(); i++) {
            mint now = zeta;
            for (int j = 0; j < f.size(); j++) {
                ret[i] += f[j] * now;
                now *= zeta;
            }
            zeta *= ((inverse == 1) ? ROOT[0] : inv_ROOT[0]);
        }
        return ret;
    }

    vectorM f1(f.size() / 2), f2(f.size() / 2);
    //f1とf2を作る。
    for (int i = 0; i < f.size() / 2; i++) {
        f1[i] = f[i * 2];
        f2[i] = f[i * 2 + 1];
    }

    vectorM f1_dft = ntt(f1, inverse, log2_f - 1, divide_cnt  -1), f2_dft = ntt(f2, inverse, log2_f - 1, divide_cnt - 1);
    ret.resize(f.size());
    mint now = 1;

    for (int i = 0; i < f.size(); i++) {
        ret[i] = f1_dft[i % f1_dft.size()] + now * f2_dft[i % f2_dft.size()];
        now *= ((inverse == 1) ? ROOT[log2_f] : inv_ROOT[log2_f]);
    }
    return ret;
}

//eraseHigh0は高次項が係数ゼロ、vectorから排除するかどうか
vectorM mulp(const vectorM& _f, const vectorM& _g) {
    vectorM f = _f, g = _g;

    //fとgの次数の和以上の最小の2冪-1を次数とする。
    int max_dim = 1, log2_max_dim = 0;
    while (f.size() + g.size() > max_dim) max_dim <<= 1, log2_max_dim++;
    f.resize(max_dim), g.resize(max_dim);
    //多項式fとgのDFT結果を求める。 O(n log n)
    vectorM f_dft = ntt(f, 1, log2_max_dim), g_dft = ntt(g, 1, log2_max_dim);

    //f*gのDFT結果は各f_dftとg_ftの係数の積。O(n)
    vectorM fg_dft(max_dim);
    for (int i = 0; i < max_dim; i++) {
        fg_dft[i] = f_dft[i] * g_dft[i];
    }

    //fg_dftをDFT
    vectorM fg = ntt(fg_dft, -1, log2_max_dim);

    //最後にmax_dimで割る
    for (int i = 0; i < fg.size(); i++) {
        fg[i] = fg[i] / max_dim;
    }
    return fg;
}

最後に

NTTについてかなり一杯学べました。これがみんなの役に立てれば幸いです。

参考文献