P4245-【模板】任意模数NTT

目录

P4245-【模板】任意模数 NTT

给定 $2$ 个多项式 $F(x), G(x)$,请求出 $F(x) * G(x)$。

系数对 $p$ 取模,且不保证 $p$ 可以分解成 $p = a \cdot 2^k + 1$ 之形式。

输入共 $3$ 行。
第一行 $3$ 个整数 $n, m, p$,分别表示 $F(x), G(x)$ 的次数以及模数 $p$。
第二行为 $n+1$ 个整数,第 $i$ 个整数 $a_i$ 表示 $F(x)$ 的 $i-1$ 次项的系数。
第三行为 $m+1$ 个整数,第 $i$ 个整数 $b_i$ 表示 $G(x)$ 的 $i-1$ 次项的系数。

输出 $n+m+1$ 个整数,第 $i$ 个整数 $c_i$ 表示 $F(x) * G(x)$ 的 $i-1$ 次项的系数。

1
2
3
5 8 28
19 32 0 182 99 95
77 54 15 3 98 66 21 20 38
1
7 18 25 19 5 13 12 2 9 22 5 27 6 26
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
#include "ybwhead/ios.h"
// Target modulus p of the problem. Read from input in main() and used by
// Int::get() to reduce the CRT-reconstructed coefficient. Note that pw() and
// inv() take a parameter that is also called `mod` and shadows this global.
int mod;
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param base value to raise; it is widened to 64 bits before multiplying,
 *             so it does not have to be reduced modulo `mod` beforehand
 * @param p    exponent, expected to be non-negative
 * @param mod  modulus (this parameter shadows the global of the same name)
 * @return base^p modulo mod; for p == 0 this is 1 % mod, i.e. 0 when mod == 1
 */
inline int pw(int base, int p, const int mod)
{
    // Plain locals instead of a function-local static accumulator, so the
    // function stays reentrant.
    long long cur = base;
    long long res = 1 % mod;
    for (; p; p >>= 1)
    {
        if (p & 1)
            res = res * cur % mod;
        cur = cur * cur % mod;
    }
    return static_cast<int>(res);
}
/**
 * Modular inverse obtained as x^(mod-2) via pw(), i.e. by Fermat's little
 * theorem; this is only a true inverse when `mod` is prime and x is not a
 * multiple of it.
 */
inline int inv(int x, const int mod)
{
    const int exponent = mod - 2;
    return pw(x, exponent, mod);
}

// Three NTT-friendly primes (998244353 = 119*2^23 + 1, 1004535809 = 479*2^21 + 1,
// 469762049 = 7*2^26 + 1). G = 3 is the base that init() raises to
// (mod_k - 1) / lim to obtain a lim-th root of unity for each prime.
const int mod1 = 998244353, mod2 = 1004535809, mod3 = 469762049, G = 3;
// Product mod1 * mod2, kept in 64 bits; used by Int::get() during CRT recombination.
const long long mod_1_2 = static_cast<long long>(mod1) * mod2;
// inv_1 = mod1^{-1} modulo mod2, inv_2 = (mod1 * mod2)^{-1} modulo mod3:
// the two inverses needed to merge the three residues in Int::get().
const int inv_1 = inv(mod1, mod2), inv_2 = inv(mod_1_2 % mod3, mod3);
/**
 * One integer held simultaneously as residues modulo mod1, mod2 and mod3.
 *
 * The arithmetic operators work component-wise. operator+ / operator- expect
 * both operands to have every residue in [0, mod_k) and keep the result in
 * that range; operator* always yields residues in that range for non-negative
 * inputs. get() merges the three residues with the Chinese Remainder Theorem
 * and reduces the result modulo the global `mod`.
 */
struct Int
{
    int A, B, C; // residues modulo mod1, mod2 and mod3 respectively

    /** Zero in all three residue systems. */
    explicit inline Int() : A(0), B(0), C(0) {}

    /**
     * Embeds one integer into all three residue systems. The value is reduced
     * per modulus, because a single input (e.g. close to 1e9) may exceed mod1
     * or mod3 and would otherwise break the [0, mod_k) invariant.
     */
    explicit inline Int(int num) : A(norm(num, mod1)), B(norm(num, mod2)), C(norm(num, mod3)) {}

    /** Stores three residues verbatim; no reduction is applied. */
    explicit inline Int(int a, int b, int c) : A(a), B(b), C(c) {}

    /**
     * Maps each component from (-mod_k, mod_k) back into [0, mod_k):
     * `v >> 31` is all ones for a negative v, so mod_k is added exactly when
     * the component is negative.
     */
    static inline Int reduce(const Int &x)
    {
        return Int(x.A + (x.A >> 31 & mod1), x.B + (x.B >> 31 & mod2), x.C + (x.C >> 31 & mod3));
    }

    /** Component-wise modular addition. */
    inline friend Int operator+(const Int &lhs, const Int &rhs)
    {
        // Subtract the modulus first, then let reduce() add it back if the
        // sum was below the modulus.
        return reduce(Int(lhs.A + rhs.A - mod1, lhs.B + rhs.B - mod2, lhs.C + rhs.C - mod3));
    }

    /** Component-wise modular subtraction. */
    inline friend Int operator-(const Int &lhs, const Int &rhs)
    {
        return reduce(Int(lhs.A - rhs.A, lhs.B - rhs.B, lhs.C - rhs.C));
    }

    /** Component-wise modular multiplication using 64-bit intermediates. */
    inline friend Int operator*(const Int &lhs, const Int &rhs)
    {
        return Int(static_cast<int>(static_cast<long long>(lhs.A) * rhs.A % mod1),
                   static_cast<int>(static_cast<long long>(lhs.B) * rhs.B % mod2),
                   static_cast<int>(static_cast<long long>(lhs.C) * rhs.C % mod3));
    }

    /**
     * Reconstructs the represented value by CRT and returns it modulo the
     * global `mod`. Expects A, B, C to be in [0, mod_k).
     *
     * Step 1: x = A + mod1 * ((B - A) * inv_1 mod mod2) is the value modulo
     *         mod1 * mod2 (it fits in 64 bits).
     * Step 2: the value is x + mod1 * mod2 * ((C - x) * inv_2 mod mod3); the
     *         large factor mod1 * mod2 is reduced modulo `mod` before it is
     *         multiplied so that everything stays inside 64 bits.
     */
    inline int get() const
    {
        const long long x = static_cast<long long>(B - A + mod2) % mod2 * inv_1 % mod2 * mod1 + A;
        const long long k = static_cast<long long>(C - x % mod3 + mod3) % mod3 * inv_2 % mod3;
        return static_cast<int>((k * (mod_1_2 % mod) % mod + x) % mod);
    }

private:
    /** Reduces v into [0, m) for a positive modulus m, also for negative v. */
    static inline int norm(int v, int m)
    {
        v %= m;
        return v < 0 ? v + m : v;
    }
};

// maxn = 2^17; the coefficient arrays and NTT tables are sized from it.
#define maxn 131072

// N = 2 * maxn: upper bound on the transform length `lim` that the tables below can hold.
#define N (maxn << 1)
// lim: current transform length (a power of two), set by init().
// s:   log2(lim) - 1, the shift used when building rev[] in init().
// rev: table filled by init() and used by NTT() to reorder the input.
int lim, s, rev[N];
// Wn[k] holds the k-th power of the root chosen in init(), for k = 0..lim
// inclusive, which is why the array has N | 1 = N + 1 entries.
Int Wn[N | 1];
/**
 * Prepares the global tables for transforms that can hold n coefficients.
 *
 * Sets `lim` to the smallest power of two that is >= n (at least 1) and `s`
 * to log2(lim) - 1, fills rev[0..lim) with the bit-reversed index of every
 * position, and fills Wn[0..lim] with the successive powers of
 * G^((mod_k - 1) / lim) in each of the three residue systems.
 *
 * The caller must make sure the resulting lim does not exceed the table
 * capacity N.
 */
inline void init(int n)
{
    s = -1, lim = 1;
    while (lim < n)
        lim <<= 1, ++s;

    // rev[i] is rev[i / 2] shifted down by one bit, with the lowest bit of i
    // moved to the top position (bit s).
    rev[0] = 0;
    for (int i = 1; i < lim; ++i)
        rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;

    // One step of the root table: G^((mod_k - 1) / lim) for each modulus.
    const Int t(pw(G, (mod1 - 1) / lim, mod1), pw(G, (mod2 - 1) / lim, mod2), pw(G, (mod3 - 1) / lim, mod3));
    Wn[0] = Int(1);
    // Wn[lim] is written as well; NTT() reads Wn[lim - t * j] with j == 0
    // when running the inverse transform.
    for (int i = 0; i < lim; ++i)
        Wn[i + 1] = Wn[i] * t;
}
/**
 * In-place number-theoretic transform of A[0..lim) in all three residue
 * systems at once. Relies on `lim`, `rev` and `Wn` as set up by init().
 *
 * @param A  array with at least `lim` elements; overwritten with the result
 * @param op non-zero: forward transform; zero: inverse transform, which reads
 *           the root table backwards and finally multiplies every element by
 *           lim^{-1}
 */
inline void NTT(Int *A, const int op = 1)
{
    // Reorder the input according to rev[]; the i < rev[i] test makes sure
    // each pair is swapped only once.
    for (int i = 1; i < lim; ++i)
        if (i < rev[i])
            std::swap(A[i], A[rev[i]]);

    // Butterfly passes: `mid` is the half-length of the blocks being merged.
    for (int mid = 1; mid < lim; mid <<= 1)
    {
        // Stride into Wn so that Wn[t * j] is the j-th power of the
        // (2 * mid)-th root of unity.
        const int t = lim / mid >> 1;
        for (int i = 0; i < lim; i += mid << 1)
        {
            for (int j = 0; j < mid; ++j)
            {
                // Wn[lim - k] is the inverse of Wn[k], used for the inverse transform.
                const Int W = op ? Wn[t * j] : Wn[lim - t * j];
                const Int X = A[i + j];
                const Int Y = A[i + j + mid] * W;
                A[i + j] = X + Y;
                A[i + j + mid] = X - Y;
            }
        }
    }

    if (!op)
    {
        // Scale by 1 / lim to complete the inverse transform.
        const Int ilim(inv(lim, mod1), inv(lim, mod2), inv(lim, mod3));
        for (int i = 0; i < lim; ++i)
            A[i] = A[i] * ilim;
    }
}
// N was only needed to size the NTT tables above.
#undef N

// n, m: read in main() as the degrees of the two polynomials and then
// incremented there to become their coefficient counts.
int n, m;
// Coefficient buffers of the two polynomials; A also receives the product.
// Both have 2 * maxn entries, i.e. room for a transform of that length.
Int A[maxn << 1], B[maxn << 1];
int main()
{
    yin >> n >> m >> mod;
    ++n, ++m;
    for (int i = 0, x; i < n; ++i)
        yin >> x, A[i] = Int(x % mod);
    for (int i = 0, x; i < m; ++i)
        yin >> x, B[i] = Int(x % mod);
    init(n + m);
    NTT(A), NTT(B);
    for (int i = 0; i < lim; ++i)
        A[i] = A[i] * B[i];
    NTT(A, 0);
    for (int i = 0; i < n + m - 1; ++i)
    {
        yout << A[i].get() << " ";
    }
    return 0;
}

来发评论吧~
Powered By Valine
v1.4.14