题解 影魔

时间:2021-07-16
本文章向大家介绍题解 影魔,主要包括题解 影魔使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

传送门

dsu on tree的板子忘光了,连函数名都忘了

强制在线的做法先留个坑,
这里其实可以离线(考场上并没有想到dsu on tree能搭配离线一起用)
那问题就是如何合并信息
我们需要统计深度不超过给定值的节点中共有多少种颜色
发现对同一种颜色,深度越小的节点能被越多的深度上限覆盖(“存活”得越久)
所以对每种颜色只需记录它出现的最小深度,再支持查询最小深度不超过给定值的颜色个数
可以用平衡树维护
但这里需要以深度排序,
splay支持从某个点分开但不支持重复元素
FHQ支持重复元素但不支持从某个点划分
所以考虑树套树,外层splay维护出现过的所有深度,
每个节点上再开一棵平衡树统计对于这个深度出现的颜色
时间复杂度 \(O(n\log^2 n)\)

Code:

#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long 
#define ld long double
#define usd unsigned
#define ull unsigned long long
//#define int long long 

// https://www.cnblogs.com/StarRoadTang/p/14033777.html
// https://www.luogu.com.cn/record/51724433

// Buffered input: refills `buf` from stdin 2 MiB at a time and yields one byte
// per call as a non-negative int (via unsigned char), or EOF once stdin is
// exhausted.  Going through unsigned char keeps ctype functions away from the
// negative values that are undefined behaviour for them.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
char buf[1<<21], *p1=buf, *p2=buf;
// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; every '-' among them
// flips the sign.  Returns 0 when input ends before any digit is seen instead
// of spinning forever on EOF.
inline int read() {
	int ans=0, f=1, ch=getchar();
	while (ch!=EOF && !isdigit(ch)) {if (ch=='-') f=-f; ch=getchar();}
	while (isdigit(ch)) {ans=ans*10+(ch-'0'); ch=getchar();}
	return ans*f;
}

int n, m;
// Tree stored as singly linked adjacency lists in `e`: head[u] is the index of
// u's first edge (0 = none) and `ecnt` the number of edges stored so far.
// c[u] is the colour of vertex u; (qu[i], qd[i]) is the i-th query.
// The counter is deliberately not called `size`: a global `size` together with
// `using namespace std;` is ambiguous with std::size under C++17.
int head[N], ecnt, c[N], qu[N], qd[N];
struct edge{int to, next;}; edge* e;   // allocated in main with n+10 entries
// Appends the edge s -> t (main calls it as add(parent, child)).
inline void add(int s, int t) {edge* k=&e[++ecnt]; k->to=t; k->next=head[s]; head[s]=ecnt;}

namespace force{
	int cnt, sta[N], top;
	bool vis[N];
	void dfs2(int u, int d) {
		if (d<0) return ;
		if (!vis[c[u]]) {vis[c[u]]=1; sta[++top]=c[u]; ++cnt;}
		for (int i=head[u]; i; i=e[i].next) dfs2(e[i].to, d-1);
	}
	void solve() {
		for (int i=1,u,d; i<=m; ++i) {
			cnt=0;
			while (top) vis[sta[top--]]=0;
			//memset(vis, 0, sizeof(vis));
			dfs2(qu[i], qd[i]);
			printf("%d\n", cnt);
		}
	}
}

namespace task1{
	// Special case where every query reaches the bottom of its subtree (the
	// `qd >= n` check in main): the answer is the number of distinct colours in
	// the subtree of qu, obtained by merging per-vertex dynamic segment trees
	// keyed by colour in [1, n].
	const int SIZE=N*80;
	// Node pool.  [tl, tr]: colour range of the node; cnt: number of distinct
	// colours present in that range; lson/rson: children (0 = empty).
	int tot, rot[N], tl[SIZE], tr[SIZE], cnt[SIZE], lson[SIZE], rson[SIZE];
	#define tl(a) tl[a]
	#define tr(a) tr[a]
	#define cnt(a) cnt[a]
	#define l(a) lson[a]
	#define r(a) rson[a]
	#define pushup(p) cnt(p)=cnt(l(p))+cnt(r(p))
	// Marks colour `pos` as present in the tree rooted at p (nodes created on demand).
	void upd(int& p, int l, int r, int pos) {
		if (!p) {p=++tot; tl(p)=l; tr(p)=r;}
		if (l>=r) {cnt(p)=1; return ;}
		int mid=(l+r)>>1;
		if (pos<=mid) upd(l(p), l, mid, pos);
		else upd(r(p), mid+1, r, pos);
		pushup(p);
	}
	// Number of distinct colours in [l, r] inside the tree rooted at p.
	int query(int p, int l, int r) {
		if (!p) return 0;
		if (l<=tl(p)&&r>=tr(p)) return cnt(p);
		int mid=(tl(p)+tr(p))>>1, ans=0;
		if (l<=mid) ans+=query(l(p), l, r);
		if (r>mid) ans+=query(r(p), l, r);
		return ans;
	}
	// Merges tree p2 into p1 and returns the resulting root.  Nodes of p2 are
	// never linked into p1: where p1 is empty a fresh node is allocated and the
	// recursion copies p2, so rot[] of an already merged child stays intact.
	int merge(int p1, int p2, int l, int r) {
		if (!p2) return p1;
		if (!p1) {p1=++tot; tl(p1)=l; tr(p1)=r;}
		if (l>=r) {if (cnt(p1)||cnt(p2)) cnt(p1)=1; return p1;}
		int mid=(l+r)>>1;
		l(p1)=merge(l(p1), l(p2), l, mid);
		r(p1)=merge(r(p1), r(p2), mid+1, r);
		pushup(p1);
		return p1;
	}
	
	// Builds rot[u] = set of colours in the subtree of u.
	void dfs(int u) {
		upd(rot[u], 1, n, c[u]);
		for (int i=head[u],v; i; i=e[i].next) {
			v = e[i].to;
			dfs(v);
			rot[u]=merge(rot[u], rot[v], 1, n);
		}
	}
	void solve() {
		//cout<<double(sizeof(tr)*5)/1024/1024<<endl;
		dfs(1);
		for (int i=1; i<=m; ++i) printf("%d\n", cnt(rot[qu[i]]));
		exit(0);
	}
	// Drop every helper macro so the short names (notably l and r) do not leak
	// into the rest of the file.
	#undef tl
	#undef tr
	#undef cnt
	#undef l
	#undef r
	#undef pushup
}

namespace task{
	int h2[N], siz2;
	int siz[N], msiz[N], mson[N], ans[N], dep[N], rec[N], nowson;
	struct edge2{int d, next, rk;}; edge2* e2;
	inline void add2(int s, int t, int i) {edge2* k=&e2[++siz2]; k->d=t; k->next=h2[s]; k->rk=i; h2[s]=siz2;}
	
	const int SIZE=N*10;
	int tot, fa[SIZE], son[SIZE][2], val[SIZE], cnt[SIZE], size[SIZE];
	#define fa(a) fa[a]
	#define son(a, b) son[a][b]
	#define val(a) val[a]
	#define set(a) splay1[val[a]]
	#define cnt(a) cnt[a]
	#define size(a) size[a]
	#define loc(a) (son(fa(a), 1)==a)
	#define tran(a, x) son(a, x>val(a))
	#define pushup(a) size(a)=size(son(a, 0))+size(son(a, 1))+cnt(a)
	// 维护颜色
	struct Splay1{
		int rot;
		Splay1(){ins(-INF); ins(INF);}
		void ror(int x) {
			int y=fa(x), z=fa(y), k=loc(x);
			son(z, loc(y))=x; fa(x)=z;
			son(y, k)=son(x, k^1); fa(son(x, k^1))=y;
			son(x, k^1)=y; fa(y)=x;
			pushup(y); pushup(x);
		}
		void splay(int x, int pos) {
			int y, z;
			while (fa(x)!=pos) {
				y=fa(x); z=fa(y);
				if (z!=pos) loc(x)^loc(y)?ror(x):ror(y);
				ror(x);
			}
			if (!pos) rot=x;
		}
		void ins(int x) {
			int u=rot, f=0;
			while (u&&val(u)!=x) f=u, u=tran(u, x);
			if (u) puts("error1");
			else {
				u=++tot;
				if (f) tran(f, x)=u;
				size(u)=1;
				val(u)=x; fa(u)=f;
			}
			splay(u, 0);
		}
		void find(int x) {
			int u=rot;
			while (val(u)!=x&&tran(u, x)) u=tran(u, x);
			splay(u, 0);
		}
		int suf(int x, int f) {
			find(x);
			int u=rot;
			if (val(u)!=x && f^(x>val(u))) return u;
			u=son(u, f); f^=1;
			while (son(u, f)) u=son(u, f);
			return u;
		}
		void remove(int x) {
			int lst=suf(x, 0), nxt=suf(x, 1);
			splay(lst, 0); splay(nxt, lst);
			assert(son(nxt, 0));
			//puts("remove successfully!");
			son(nxt, 0)=0; splay(nxt, 0);
		}
	}splay1[N];
	// 维护深度
	struct Splay2{
		int rot;
		Splay2(){ins(-INF); ins(INF);}
		void ror(int x) {
			int y=fa(x), z=fa(y), k=loc(x);
			son(z, loc(y))=x; fa(x)=z;
			son(y, k)=son(x, k^1); fa(son(x, k^1))=y;
			son(x, k^1)=y; fa(y)=x;
			pushup(y); pushup(x);
		}
		void splay(int x, int pos) {
			int y, z;
			while (fa(x)!=pos) {
				y=fa(x); z=fa(y);
				if (z!=pos) loc(x)^loc(y)?ror(x):ror(y);
				ror(x);
			}
			if (!pos) rot=x;
		}
		void ins(int d) {
			int u=rot, f=0;
			while (u&&val(u)!=d) f=u, u=tran(u, d);
			if (u) ++cnt(u);
			else {
				u=++tot;
				if (f) tran(f, d)=u;
				cnt(u)=size(u)=1;
				val(u)=d; fa(u)=f;
			}
			splay(u, 0);
		}
		void ins(int d, int c) {
			
			int u=rot, f=0;
			while (u&&val(u)!=d) f=u, u=tran(u, d);
			if (u) ++cnt(u);
			else {
				u=++tot;
				if (f) tran(f, d)=u;
				cnt(u)=size(u)=1;
				val(u)=d; fa(u)=f;
			}
			set(u).ins(c);
			splay(u, 0);
		}
		void find(int x) {
			int u=rot;
			while (val(u)!=x&&tran(u, x)) u=tran(u, x);
			splay(u, 0);
		}
		int suf(int x, int f) {
			find(x);
			int u=rot;
			if (val(u)!=x && f^(x>val(u))) return u;
			u=son(u, f); f^=1;
			while (son(u, f)) u=son(u, f);
			return u;
		}
		void del(int d) {
			//cout<<"del "<<d<<endl;
			int lst=suf(d, 0), nxt=suf(d, 1);
			splay(lst, 0); splay(nxt, lst);
			son(nxt, 0)=0; splay(nxt, 0);
		}
		void remove(int d, int c) {
			//cout<<"remove "<<d<<' '<<c<<endl;
			find(d);
			set(rot).remove(c);
			if (--cnt(rot)<=0) del(d);
			else pushup(rot);
		}
		void puts(int p) {
			if (!p) return ;
			puts(son(p, 0));
			printf("%d ", val(p));
			puts(son(p, 1));
		}
		int query(int d) {
			splay(suf(d, 1), 0);
			//cout<<val(rot)<<endl;
			//cout<<val(son(rot, 0))<<endl;
			//puts(rot); cout<<endl;
			return size(son(rot, 0))-1;
		}
	}mt;
	
	void ins(int d, int c) {
		if (rec[c]&&d>rec[c]) return ;
		if (rec[c]) mt.remove(rec[c], c);
		mt.ins(d, c);
		rec[c]=d;
	}
	void remove(int d, int c) {
		assert(d==rec[c]);
		mt.remove(d, c);
	}
	
	void dfs(int u) {
		siz[u]=1;
		for (int i=head[u],v; i; i=e[i].next) {
			v = e[i].to;
			dep[v]=dep[u]+1; dfs(v);
			siz[u]+=siz[v];
			if (siz[v]>msiz[u]) msiz[u]=siz[v], mson[u]=v;
		}
	}
	void info(int u, bool op) {
		if (op && (!rec[c[u]]||rec[c[u]]>dep[u])) ins(dep[u], c[u]), rec[c[u]]=dep[u]; //, printf("at %d ins dep:%d col:%d\n", u, dep[u], c[u]);
		else if (!op && rec[c[u]]==dep[u]) remove(dep[u], c[u]), rec[c[u]]=0; //, printf("at %d remove dep:%d col:%d\n", u, dep[u], c[u]);
		//else if (!op) printf("at %d op=0 but rec=%d, dep[u]=%d\n", u, rec[c[u]], dep[u]);
		for (int i=head[u],v; i; i=e[i].next) {
			v = e[i].to;
			if (v!=nowson) info(v, op);
		}
	}
	void dsu(int u, bool op) {
		//cout<<"dsu "<<u<<' '<<op<<endl;
		for (int i=head[u],v; i; i=e[i].next) {
			v = e[i].to;
			if (v!=mson[u]) dsu(v, 0);
		}
		if (mson[u]) dsu(mson[u], 1), nowson=mson[u];
		info(u, 1);
		//cout<<"at "<<u<<" try to get ans"<<endl;
		//for (int i=h2[u],v; i; i=e2[i].next) ans[e2[i].rk]=mt.query(dep[u]+e2[i].d)-mt.query(dep[u]-1); //, cout<<"query "<<mt.query(dep[u]+e2[i].d)<<' '<<mt.query(dep[u]-1)<<endl;
		for (int i=h2[u],v; i; i=e2[i].next) ans[e2[i].rk]=mt.query(dep[u]+e2[i].d);
		nowson=0;
		if (!op) info(u, 0);
	}
	void solve() {
		e2 = new edge2[m+10];
		for (int i=1; i<=m; ++i) add2(qu[i], qd[i], i);
		dep[1]=1;
		dfs(1);
		dsu(1, 1);
		for (int i=1; i<=m; ++i) printf("%d\n", ans[i]);
		exit(0);
	}
}

signed main()
{
	#ifdef DEBUG
	freopen("1.in", "r", stdin);
	#endif
	bool flag=1;
	
	n=read(); m=read();
	e = new edge[n+10];
	for (int i=1; i<=n; ++i) c[i]=read();
	for (int i=1; i<n; ++i) add(read(), i+1);
	for (int i=1; i<=m; ++i) {
		qu[i]=read(); qd[i]=read();
		if (qd[i]<n) flag=0;
	}
	//if (flag) task1::solve();
	//else force::solve();
	task::solve();	

	return 0;
}

原文地址:https://www.cnblogs.com/narration/p/15020833.html