ACM_Notebook_new

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub ngthanhtrung23/ACM_Notebook_new

:heavy_check_mark: Math/tests/ntt_chemthan.test.cpp

Depends on

Code

#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod"

#include <bits/stdc++.h>
using namespace std;

#include "../Polynomial/NTT_chemthan.h"

NTT<998244353, 1<<20> ntt;

#define REP(i, a) for (int i = 0, _##i = (a); i < _##i; ++i)

int32_t main() {
    ios::sync_with_stdio(0); cin.tie(0);
    int n, m; cin >> n >> m;
    vector<int> a(n); REP(i,n) cin >> a[i];
    vector<int> b(m); REP(i,m) cin >> b[i];

    auto c = ntt.multiply(a, b);
    for (int x : c) cout << x << ' ';
    cout << endl;
    return 0;
}
#line 1 "Math/tests/ntt_chemthan.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod"

#include <bits/stdc++.h>
using namespace std;

#line 1 "Math/Polynomial/NTT_chemthan.h"
// Copied from chemthan
// 2x slower than atcoder library
// Tested:
// - https://oj.vnoi.info/problem/icpc21_mt_d
// - https://judge.yosupo.jp/problem/convolution_mod
// - https://judge.yosupo.jp/problem/convolution_mod_1000000007
//
// Sample usage: Multiply big-int polynomials using NTT + CRT
//   NTT<MOD0, 1 << 21> ntt0;
//   NTT<MOD1, 1 << 21> ntt1;
//   auto r0 = ntt0.multiply(v1, v2);
//   auto r1 = ntt1.multiply(v1, v2);
//
//   // Using CRT to combine r0 and r1
//   CRT<int> crt;
//   crt.add(MOD0, r0[idx]);
//   crt.add(MOD1, r1[idx]);
//   cout << crt.res << endl;

// mod must be NTT mod
// maxf = max degree of c. Should be 2^k?
template<const int mod, const int maxf>
struct NTT {
    NTT() {
        int k = 0; while ((1 << k) < maxf) k++;
        bitrev[0] = 0;
        for (int i = 1; i < maxf; i++) {
            bitrev[i] = bitrev[i >> 1] >> 1 | ((i & 1) << (k - 1));
        }
        int pw = fpow(prt(), (mod - 1) / maxf);
        rts[0] = 1;
        for (int i = 1; i <= maxf; i++) {
            rts[i] = (long long) rts[i - 1] * pw % mod;
        }
        for (int i = 1; i <= maxf; i <<= 1) {
            iv[i] = fpow(i, mod - 2);
        }
    }

    vector<int> multiply(vector<int> a, vector<int> b) {
        static int fa[maxf], fb[maxf], fc[maxf];
        int na = a.size(), nb = b.size();
        for (int i = 0; i < na; i++) fa[i] = a[i];
        for (int i = 0; i < nb; i++) fb[i] = b[i];
        multiply(fa, fb, na, nb, fc);
        int k = na + nb - 1;
        vector<int> res(k);
        for (int i = 0; i < k; i++) res[i] = fc[i];
        return res;
    }

private:
    int rts[maxf + 1];
    int bitrev[maxf];
    int iv[maxf + 1];

    int fpow(int a, int k) {
        if (!k) return 1;
        int res = a, tmp = a;
        k--;
        while (k) {
            if (k & 1) {
                res = (long long) res * tmp % mod;
            }
            tmp = (long long) tmp * tmp % mod;
            k >>= 1;
        }
        return res;
    }
    int prt() {
        vector<int> dvs;
        for (int i = 2; i * i < mod; i++) {
            if ((mod - 1) % i) continue;
            dvs.push_back(i);
            if (i * i != mod - 1) dvs.push_back((mod - 1) / i);
        }
        for (int i = 2; i < mod; i++) {
            int flag = 1;
            for (int j = 0; j < (int) dvs.size(); j++) {
                if (fpow(i, dvs[j]) == 1) {
                    flag = 0;
                    break;
                }
            }
            if (flag) return i;
        }
        assert(0);
        return -1;
    }
    void dft(int a[], int n, int sign) {
        int d = 0; while ((1 << d) * n != maxf) d++;
        for (int i = 0; i < n; i++) {
            if (i < (bitrev[i] >> d)) {
                swap(a[i], a[bitrev[i] >> d]);
            }
        }
        for (int len = 2; len <= n; len <<= 1) {
            int delta = maxf / len * sign;
            for (int i = 0; i < n; i += len) {
                int *w = sign > 0 ? rts : rts + maxf;
                for (int k = 0; k + k < len; k++) {
                    int &a1 = a[i + k + (len >> 1)], &a2 = a[i + k];
                    int t = (long long) *w * a1 % mod;
                    a1 = a2 - t;
                    a2 = a2 + t;
                    a1 += a1 < 0 ? mod : 0;
                    a2 -= a2 >= mod ? mod : 0;
                    w += delta;
                }
            }
        }
        if (sign < 0) {
            int in = iv[n];
            for (int i = 0; i < n; i++) {
                a[i] = (long long) a[i] * in % mod;
            }
        }
    }
    void multiply(int a[], int b[], int na, int nb, int c[]) {
        static int fa[maxf], fb[maxf];
        int n = na + nb - 1; while (n != (n & -n)) n += n & -n;
        for (int i = 0; i < n; i++) fa[i] = fb[i] = 0;
        for (int i = 0; i < na; i++) fa[i] = a[i];
        for (int i = 0; i < nb; i++) fb[i] = b[i];
        dft(fa, n, 1), dft(fb, n, 1);
        for (int i = 0; i < n; i++) fa[i] = (long long) fa[i] * fb[i] % mod;
        dft(fa, n, -1);
        for (int i = 0; i < n; i++) c[i] = fa[i];
    }
};

/* Examples
const int MOD0 = 1004535809; //2^21 * 479 + 1
const int MOD1 = 1012924417; //2^21 * 483 + 1
const int MOD2 = 998244353;  //2^20 * 476 + 1
NTT<MOD0, 1 << 21> ntt0;
NTT<MOD1, 1 << 21> ntt1;
*/
#line 7 "Math/tests/ntt_chemthan.test.cpp"

NTT<998244353, 1<<20> ntt;

#define REP(i, a) for (int i = 0, _##i = (a); i < _##i; ++i)

int32_t main() {
    ios::sync_with_stdio(0); cin.tie(0);
    int n, m; cin >> n >> m;
    vector<int> a(n); REP(i,n) cin >> a[i];
    vector<int> b(m); REP(i,m) cin >> b[i];

    auto c = ntt.multiply(a, b);
    for (int x : c) cout << x << ' ';
    cout << endl;
    return 0;
}
Back to top page