lp3233 HNOI2014 世界树

让我们来分析这道题。
我们很容易地可以发现,如果一个节点是末端节点,那么它肯定是由它祖先中的第一个节点管辖更优。
由此我们继续推导,可以发现,假如只有两个关键点的话,它们之间的链上一定可以找到一个分界线,使得这条分界线的一段——以及所有与链上这一端相邻的节点——放在一个关键点上,并把剩下的放在另一个关键点上。这样是最优的。
然后我们考虑有三个关键点的情况。如果有三个关键点的话,问题似乎就复杂了很多,这是因为这条分界线必须是基于三个节点共同决定的。
我们换一个思路考虑。求出三个关键点两两的LCA。那么,这一分界线就必然可以在这六个点(最多六个)两两之间的链上找到。
存在一种叫做「虚树」的数据结构。这个数据结构本质上就是将所有关键点以及关键点两两之间的LCA建成一个数。
我们如此考虑,如果有一个点是关键点,那么它向下能够管辖的点的数量就是它的子树大小减去它的所有子节点能管辖的点数之和。
它向上能够管辖的点的数量则相对比较复杂。
我们考虑每一条边,假设我们已经预处理出了虚树上每一个点相邻最近的关键点,那么,如果这条边两端的点相邻最近的关键点是相同的,那么这条边(以及和这条边相邻的所有点)都划归到那个关键点下管辖。
如果这条边两段的点相邻最近的关键点不同,则需要使用倍增来确定这条边要如何被划分成两个部分。

接下来要考虑如何确定虚树上每一个点相邻最近的关键点。
我们不妨这样考虑:一个点相邻最近的关键点如果在它的下方,那么这可以通过树形DP求得;如果在它的上方或者其他子树内呢?
显然,如果这个点的相邻最近关键点在它的上方或者和它不在同一子树内,那么它的父亲的最近关键点一定与这个点的最近关键点相同。
于是,我们不妨从上而下DP,求得这个点在上方或者不在同一子树的最近关键点。

现在我们来梳理一下这一题的思路。
首先,我们进行预处理,处理出每个点的深度和dfn。
然后,我们进行两次树形DP,求出每个点的最近关键点。
第三,我们预处理出所有「末端节点」——也就是所有「是一个虚树上节点的直接儿子且子树内没有关键点的原树上的点」。这些点的贡献可以直接统计到它们父亲的最近关键点。
最后,我们依次考虑剩下的虚树边上的点。如果两段的最近关键点相同,那么就统计到那个最近关键点。否则就进行倍增寻找答案。

顺便讲一下如何建虚树。
我们将所有关键点按照dfs序排序,然后建一个栈,代表当前正在向下延长的链。
如果当前的点与栈顶的LCA的深度小等于次栈顶,那么就说明当前点不在栈顶的子树里,也就意味着栈顶处于当前应该维护的链外。
于是,我们需要就可以将新的这个点加入虚树和栈中。
否则,就说明原来的栈顶需要被弹出,那么就处理完它应该连的边然后将它弹出。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#define Fv(i,X) for(int i=h[X];i;i=e[i].nxt)
#define Fv2(i,X) for(int i=g[X];i;i=e[i].nxt)
using namespace std;

typedef pair<int,int> pii;
const int N=300005;
const int INF=0x3f3f3f3f;
struct ee{
	int v;
	int nxt;
}e[1200005];
int h[N],g[N],et=0;
inline void add(int *H,int U,int V){
	e[++et]=(ee){V,H[U]};
	H[U]=et;
}

int dfn[N],cnt=0,sz[N],dep[N],fa[N][20];
inline void dfs0(int X){
	dfn[X]=++cnt;sz[X]=1;
	Fv(i,X){
		if(e[i].v==fa[X][0]){
			continue;
		}
		fa[e[i].v][0]=X;dep[e[i].v]=dep[X]+1;
		dfs0(e[i].v);
		sz[X]+=sz[e[i].v];
	}
}
int usd[N],pt[N],ppt[N],hr[N],ans[N],up[N];
pii nr[N];
inline void dfs1(int X){
	nr[X]=(usd[X]?pii(0,X):pii(INF,0));
	Fv2(i,X){
		dfs1(e[i].v);
		nr[X]=min(nr[X],pii(nr[e[i].v].first+dep[e[i].v]-dep[X],nr[e[i].v].second));
	}
}
inline void dfs2(int X,int D,int P){
	if(pii(D,P)<nr[X]){
		nr[X]=pii(D,P);
	}else{
		D=nr[X].first,P=nr[X].second;	
	}
	Fv2(i,X){
		dfs2(e[i].v,D+dep[e[i].v]-dep[X],P);
	}
}
//子节点中所有子树中有虚树上节点的点都是不可以选取的。
//我们不妨逆向考虑,枚举每一个虚树上的子节点,然后从那个节点开始倍增,一直倍增到这棵树的子节点,然后把这些子节点的子树挖掉。 
inline void dfs3(int X){
	ans[hr[X]=nr[X].second]+=sz[X];
	Fv2(i,X){
		int nw=e[i].v;
		for(int j=18;j>=0;--j){
			if(fa[nw][j]&&dep[fa[nw][j]]>dep[X]){
				nw=fa[nw][j];
			}
		}
		ans[hr[X]]-=sz[up[e[i].v]=nw];
		dfs3(e[i].v);
	}
}
//现在剩下的末端节点就只有虚树上的节点了。 
//如果子节点的dfs序大于当前节点,那么分割点就偏上;否则偏下。 
inline void dfs4(int X){
	Fv2(i,X){
		if(hr[e[i].v]==hr[X]){
			ans[hr[X]]+=sz[up[e[i].v]]-sz[e[i].v];
		}else{
			int len=dep[hr[e[i].v]]+dep[X]-nr[X].first;
			len=((len&1)?(len+1)>>1:((len>>1)+(int)(hr[e[i].v]>hr[X])));
//			这里比较的是编号!!! 
			int nw=e[i].v;
			for(int j=18;j>=0;--j){
				if(dep[fa[nw][j]]>=len){
					nw=fa[nw][j];
				}
			}
			ans[hr[e[i].v]]+=sz[nw]-sz[e[i].v];
			ans[hr[X]]+=sz[up[e[i].v]]-sz[nw];
		}
		dfs4(e[i].v);
	}
}
inline void dfs5(int X){
	up[X]=hr[X]=0;
	Fv2(i,X){
		dfs5(e[i].v);
	}
	g[X]=0;
}
inline bool cmp(int A,int B){
	return dfn[A]<dfn[B];
}
inline int lca(int X,int Y){
	if(dep[X]<dep[Y]){
		swap(X,Y);
	}
	for(int i=18;i>=0;--i){
		if(dep[fa[X][i]]>=dep[Y]){
			X=fa[X][i];
		}
	}
	if(X==Y){
		return X;
	}
	for(int i=18;i>=0;--i){
		if(fa[X][i]!=fa[Y][i]){
			X=fa[X][i],Y=fa[Y][i];
		}
	}
	return fa[X][0];
}
int st[N],tp=0;
void init(){
	int n,Q;
	scanf("%d",&n);
	int u,v;
	for(int i=1;i<n;++i){
		scanf("%d%d",&u,&v);
		add(h,u,v);
		add(h,v,u);
	}
	dep[1]=1;
	dfs0(1);
	for(int j=1;j<=18;++j){
		for(int i=1;i<=n;++i){
			fa[i][j]=fa[fa[i][j-1]][j-1];
		}
	}
	scanf("%d",&Q);
	int m,X,Y;
	while(Q--){
		scanf("%d",&m);
		for(int i=1;i<=m;++i){
			scanf("%d",&pt[i]);
			ppt[i]=pt[i],usd[pt[i]]=1;
		}
		sort(pt+1,pt+1+m,cmp);
		st[tp=1]=pt[1];
		for(int i=2;i<=m;++i){
			X=pt[i],Y=lca(X,st[tp]);
			while(tp>1&&dep[Y]<=dep[st[tp-1]]){
				add(g,st[tp-1],st[tp]);
				--tp;
			}
			if(Y!=st[tp]){
				add(g,Y,st[tp]);st[tp]=Y;
			}
			st[++tp]=X;
		}
		while(tp>1){
			add(g,st[tp-1],st[tp]);
			--tp;
		}
		dfs1(st[1]);
		dfs2(st[1],nr[st[1]].first,nr[st[1]].second);
		dfs3(st[1]);
		dfs4(st[1]);
		ans[hr[st[1]]]+=sz[1]-sz[st[1]];
		for(int i=1;i<=m;++i){
			printf("%d ",ans[ppt[i]]);
		}
		puts("");
		dfs5(st[1]);
		for(int i=1;i<=m;++i){
			ans[pt[i]]=0,usd[pt[i]]=0;
		}
	}
}
int main(){
	init();
	return 0;
}