NTT

时间:2021-07-12
本文章向大家介绍NTT,主要包括NTT使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

单位根的定义就不说了。

显然有:

\[\omega_n^k=\cos {2\pi k \over n}+i\sin {2\pi k \over n} \]

带入这个可以直接证得:

\[\omega_{2n}^{2k}=\cos {2\pi \cdot 2k \over 2n}+i\sin {2\pi \cdot 2k \over2 n}=\omega_n^k \]

我们用图像理性理解可得(就是绕着原点把向量转了180度):

\[\omega_n^{k+{n\over 2}}=-\omega_n^k \]

依然根据图像可得(就是绕着原点旋转了360度):

\[\omega_n^{k+n}=\omega_n^k \]

由欧拉公式得:

\[\omega_n^k=\cos {2\pi k \over n}+i\sin {2\pi k \over n}=e^{i\cdot{2\pi k \over n}} \]

所以有:

\[\omega_n^k=e^{i\cdot{2\pi k \over n}}=(e^{i\cdot{2\pi \over n}})^k=(\omega_n^1)^k \]

我们先带入\(\omega_n^1,\omega_n^2,...,\omega_n^n\)到多项式\(A\)中,求出\(A(\omega_n^1),A(\omega_n^2),...,A(\omega_n^n)\)

为了方便,假设长度\(n\)\(2^k\)(高位不够的添0即可)

把A下标奇偶分类:

\[A_1(x)=a_0+a_2 x+a_4 x^2+... \]
\[A_2(x)=a_1+a_3 x+a+5 x^2+... \]

显然有:

\[A(x)=A_1(x^2)+xA_2(x^2) \]
\[\therefore A(\omega_{n}^{k})=A_1(\omega_{n}^{2k})+\omega_{n}^{k}A_2(\omega_{n}^{2k}) \]
\[=A_1(\omega_{n\over 2}^{k})+\omega_{n}^{k}A_2(\omega_{n\over 2}^{k}) \]
\[A(\omega_{n}^{k+{n\over 2}})=A_1(\omega_{n}^{2k+n})-\omega_{n}^{k}A_2(\omega_{n}^{2k+n}) \]
\[=A_1(\omega_{n\over 2}^{k})-\omega_{n}^{k}A_2(\omega_{n\over 2}^{k}) \]

这两个式子只有后面一项是相反的,可以递归求解。

于是给出代码:

inline void FFT(complex<double> *a, int len) {
	if (!len) return ; complex<double> a1[len], a2[len];
	for (int i = 0; i < len; ++i) a1[i] = a[i << 1], a2[i] = a[i << 1 | 1];
	FFT(a1, len >> 1); FFT(a2, len >> 1);
	complex<double> w(cos(PI / len), sin(PI / len)), wk(1, 0);
	for (int i = 0; i < len; ++i, wk *= w)
		a[i] = a1[i] + wk * a2[i], a[i + len] = a1[i] - wk * a2[i];
}

考虑怎么从点值多项式转换到系数多项式。

我们钦定\(y_i=A(\omega_n^i)\),在有一多项式\(C\),满足:

\[C(x)=\sum y_i x^i \]

则我们带入\(\omega_n^{-k}\),得到:

\[C(\omega_n^{-k})=c_k=\sum_{i=0}^{n-1} y_i (\omega_n^{-k})^i \]
\[=\sum_{i=0}^{n-1} [\sum_{j=0}^{n-1} a_j(\omega_n^{i})^j ](\omega_n^{-k})^i \]
\[=\sum_{i=0}^{n-1} \sum_{j=0}^{n-1} a_j(\omega_n^{j})^i(\omega_n^{-k})^i \]
\[=\sum_{i=0}^{n-1} \sum_{j=0}^{n-1} a_j(\omega_n^{j-k})^i \]
\[=\sum_{i=0}^{n-1} a_i \sum_{j=0}^{n-1}(\omega_n^{i-k})^j \]

设:

\[S(\omega_n^k)=\sum_{i=0}^{n-1} (\omega_n^k)^i={(\omega_n^k)^{n}-1\over \omega_n^k-1} \]

\(k\neq 0\)时为0,否则为\(n\)

则:

\[\sum_{j=0}^{n-1}(\omega_n^{i-k})^j=S(\omega_n^{i-k}) \]

即当\(i=k\)时为\(n\),所以:

\[c_k=\sum_{i=0}^{n-1} a_i \sum_{j=0}^{n-1}(\omega_n^{i-k})^j=na_k \]

我们惊讶的发现这样对\(C\)做一次FFT之后点值除以n就是多项式的系数了。

代码结合一下:

inline void FFT(complex<double> *a, int len, int flag) {
	if (!len) return ; complex<double> a1[len], a2[len];
	for (int i = 0; i < len; ++i) a1[i] = a[i << 1], a2[i] = a[i << 1 | 1];
	FFT(a1, len >> 1, flag); FFT(a2, len >> 1, flag);
	complex<double> w(cos(PI / len), flag * sin(PI / len)), wk(1, 0);
	for (int i = 0; i < len; ++i, wk *= w)
		a[i] = a1[i] + wk * a2[i], a[i + len] = a1[i] - wk * a2[i];
}

发现递归版的码会T,手玩一下发现实际上奇偶变换后下标的操作相当于二进制反过来,可以改成非递归来模拟,自己对着码手玩看看就明白。

inline void FFT(complex *a, int type) {
	for (int i = 0; i < lim; ++i)
		if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int mid = 1; mid < lim; mid <<= 1) {
		complex wn; wn = complex(cos(pi / mid), type * sin(pi / mid));
		for (int j = 0; j < lim; j += mid << 1) {
			complex bas; bas = complex(1, 0);
			for (int k = 0; k < mid; ++k, bas = bas * wn) {
				complex x = a[j + k], y = bas * a[j + mid + k];
				a[j + k] = x + y;
				a[j + mid + k] = x - y;
			}
		}
	}
}

解释一下,第一层循环枚举的是递归的层数,即当前合并的两个多项式的长度。第二层就是枚举当前要合并多项式的起点,第三层就是枚举的具体的那一个系数。这么说不是很清楚,还是自己造样例跟着代码手玩一下就明白了。

而NTT呢?设模数为p,g是p的原根,则不需要证明的给出,\(\omega_n^1\)等价于\(g^{p-1\over n} \bmod p\)。把上面的码代码里的wn换成这个就行了。一般p=998244353,此时g=3。

给个板子:

struct poly {
	int n;
	vector<ll> x;
	
	inline void NTT(int flag) {
		for (int i = 0; i < n; ++i)
			if (i < rev[i]) swap(x[i], x[rev[i]]);
		for (int mid = 1; mid < n; mid <<= 1) {
			ll wn = power(flag == 1 ? G : Gi, (mod - 1) / (mid << 1));
			for (int j = 0; j < n; j += mid << 1) {
				ll bas = 1;
				for (int k = 0; k < mid; ++k, bas = (bas * wn) % mod) {
					ll xx = x[j + k], y = (bas * x[j + mid + k]) % mod;
					x[j + k] = (xx + y) % mod;
					x[j + mid + k] = ((xx - y) % mod + mod) % mod;
				}
			} cerr << endl;
		}
	}
	
};

inline int max_(int a, int b) {
	return a > b ? a : b;
}

inline poly mul(poly A, poly B) {
	poly a, b; a = A; b = B;
	int tmp = a.n + b.n; a.n = 1;
	int L = 0;
	while (a.n <= tmp) a.n <<= 1, ++L;
	for (int i = 0; i <= a.n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L - 1);
	b.n = a.n;
	a.NTT(1); b.NTT(1);
	for (int i = 0; i < a.n; ++i) a.x[i] = (a.x[i] * b.x[i]) % mod;
	a.NTT(-1);
	const ll inv = power(a.n, mod - 2);
	a.n = tmp;
	for (int i = 0; i <= a.n; ++i) a.x[i] = (a.x[i] * inv) % mod;
	return a;
}

原文地址:https://www.cnblogs.com/wwlwakioi/p/15004166.html