任意模数NTT

时间:2019-12-18
本文章向大家介绍任意模数NTT,主要包括任意模数NTT使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

给定两个多项式\(f(x),g(x)\),求\(h(x)=f(x)*g(x)\)

\(p\)取模,\(p\)不保证可以分解成\(a*2^k+1\)

三模NTT

由于模数不满足有原根,我们可以找几个有原根的模数,求出结果再\(crt\)合并一下

考虑卷积后数字最大能达到\(p*p*len\),一般是\(10^9*10^9*10^5=10^{23}\),所以我们选择模数之积应该大于\(10^{23}\)

一般考虑\(998244353,1004535809,469762049\),因为原根都是\(3\),在\(int\)范围内且乘积较大

然后跑三个\(ntt\)(天体常数)得到

\[\begin{cases}ret\equiv a_1 (mod\ p_1)\\ret\equiv a_2 (mod\ p_2)\\ret\equiv a_3 (mod\ p_3)\\ret\equiv x (mod\ p)\end{cases}\]

我们现在就是求\(x\)

由于前三个式子直接合并爆\(long\ long\),我们用一些技巧:

先合并前两个,定义\(inv(x,p)\)\(x\)在模\(p\)意义下的逆元

\[ret\equiv a_1*p_2*inv(p_2,p_1)+a_2*p_1*inv(p_1,p_2) (mod\ p_1*p_2)\]

记作

\[ret\equiv d (mod\ m)\]

\[ret=x*m+d=y*p_3+a_3\]

\[x\equiv (a_3-d)*M^{-1} (mod\ p_3)\]

\(q=(a_3-d)*M^{-1}\),那么

\[x=k*p_3+q\]

代入\(ret\)得到

\[ret=k*p_1*p_2*p_3+q*M+d\]

然而\(ans\in [0,p_1*p_2*p_3)\),所以\(k=0\)\(ret\)就得出了

#include<bits/stdc++.h>
using namespace std;
namespace red{
#define int long long
#define eps (1e-8)
    inline int read()
    {
        int x=0;char ch,f=1;
        for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
        if(ch=='-') f=0,ch=getchar();
        while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
        return f?x:-x;
    }
    const int N=4e5+10;
    int mod[3]={469762049,998244353,1004535809};
    int n,m,p;
    int f[N],g[N],pos[N];
    int b[N],ret[N];
    int limit,len;
    inline int fast(int x,int k,int p)
    {
        int ret=1;
        while(k)
        {
            if(k&1) ret=ret*x%p;
            x=x*x%p;
            k>>=1;
        }
        return ret;
    }
    inline int slow(int x,int k,int p)
    {
        int ret=0;
        while(k)
        {
            if(k&1) ret=(ret+x)%p;
            x=(x+x)%p;
            k>>=1;
        }
        return ret;
    }
    struct poly
    {
        int g=3,p,a[N];
        inline void ntt(int limit,int *a,int inv)
        {
            for(int i=0;i<limit;++i)
                if(i<pos[i]) swap(a[i],a[pos[i]]);
            for(int mid=1;mid<limit;mid<<=1)
            {
                int Wn=fast(inv?g:(p+1)/g,(p-1)/(mid<<1),p);
                for(int r=mid<<1,j=0;j<limit;j+=r)
                {
                    int w=1;
                    for(int k=0;k<mid;++k,w=w*Wn%p)
                    {
                        int x=a[j+k],y=w*a[j+k+mid]%p;
                        a[j+k]=x+y;
                        if(a[j+k]>=p) a[j+k]-=p;
                        a[j+k+mid]=x-y;
                        if(a[j+k+mid]<0) a[j+k+mid]+=p;
                    }
                }
            }
            if(inv) return;
            inv=fast(limit,p-2,p);
            for(int i=0;i<limit;++i) a[i]=a[i]*inv%p;
        }
    }fft[3];
    inline int inv(int x,int p)
    {
        return fast(x%p,p-2,p);
    }
    inline void crt()
    {
        int len=n+m;
        int M=mod[0]*mod[1];
        int inv1=inv(mod[1],mod[0]),inv0=inv(mod[0],mod[1]),inv3=inv(M%mod[2],mod[2]);
        int a,b,c,t,k;
        for(int i=0;i<=len;++i)
        {
            a=fft[0].a[i],b=fft[1].a[i],c=fft[2].a[i];
            t=(slow(a*mod[1]%M,inv1,M)+slow(b*mod[0]%M,inv0,M))%M;
            k=((c-t%mod[2])%mod[2]+mod[2])%mod[2]*inv3%mod[2];
            ret[i]=((k%p)*(M%p)%p+t%p)%p;
        }
    }
    inline void main()
    {
        n=read(),m=read(),p=read();
        for(int i=0;i<=n;++i) f[i]=read();
        for(int i=0;i<=m;++i) g[i]=read();
        for(limit=1;limit<=n+m+2;limit<<=1) ++len;
        for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
        for(int k=0;k<=2;++k)
        {
            fft[k].p=mod[k];
            for(int i=0;i<=n;++i) fft[k].a[i]=f[i];
            for(int i=0;i<=m;++i) b[i]=g[i];
            for(int i=m+1;i<limit;++i) b[i]=0;
            fft[k].ntt(limit,fft[k].a,1);
            fft[k].ntt(limit,b,1);
            for(int i=0;i<limit;++i)  fft[k].a[i]=fft[k].a[i]*b[i]%mod[k];
            fft[k].ntt(limit,fft[k].a,0);
        }
        crt();
        for(int i=0;i<=n+m;++i) printf("%lld ",ret[i]);
    }
}
signed main()
{
    red::main();
    return 0;
}

其他方法先咕咕咕

原文地址:https://www.cnblogs.com/knife-rose/p/12058915.html