考虑这一题在链上的情况。
显然,可以建一棵线段树,修改是打标记下传,查询是左右相加特判左右相邻处颜色是否相同。
看起来在树上也可以这么做。但是仔细想想会发现不太对。
这是因为,根据树链剖分的原理,每两条剖开的链是独立查询的。
因此,在树上查询的时候,需要额外查询两个相接的链的接驳处的颜色是否相同。
这大概就做完了。
然后我们发现实时查询接驳处的颜色不仅很傻,而且容易错。
所以我们考虑开两个数组:lc和rc,储存一个区间内左端颜色和右端颜色。
另外就是链上接驳处的颜色问题。一种容易想到的方法是每一次记录它来的那个点,然后比较来的那个点和自身。
然而这么做同样是不仅傻而且容易错的。于是我们考虑另一种方法:每一次比较链顶和链顶的父亲这两个点,并将它的负贡献预先扣除。这样可以有效简化代码。
#include<iostream>
#include<cstdio>
#define Fv(i,X) for(int i=h[X];i;i=e[i].nxt)
inline void Swap(int &A,int &B){
A^=B^=A^=B;
}
int n,m;
const int N=100005;
struct ee{
int v;
int nxt;
}e[N<<1];
int h[N],et=0;
inline void Eadd(int U,int V){
e[++et]=(ee){V,h[U]};
h[U]=et;
}
inline void add(int U,int V){
Eadd(U,V);
Eadd(V,U);
}
int sn[N],sz[N],fa[N],dep[N];
inline void dfs0(int X,int FA){
fa[X]=FA;sz[X]=1;dep[X]=dep[FA]+1;sn[X]=0;
Fv(i,X){
if(e[i].v==FA){
continue;
}
dfs0(e[i].v,X);
if(sz[e[i].v]>sz[sn[X]]){
sn[X]=e[i].v;
}
sz[X]+=sz[e[i].v];
}
}
int dfn[N],tp[N],cnt=0;
inline void dfs1(int X,int FA,int TP){
dfn[X]=++cnt;tp[X]=TP;
if(sn[X]){
dfs1(sn[X],X,TP);
}
Fv(i,X){
if(e[i].v==FA||e[i].v==sn[X]){
continue;
}
dfs1(e[i].v,X,e[i].v);
}
}
int clr[N];
#define MID ((L+R)>>1)
#define LS (X<<1)
#define RS (X<<1|1)
int tr[N<<2],tg[N<<2],val[N<<2],lc[N<<2],rc[N<<2];
inline void mdf(int X,int C){
tr[X]=tg[X]=lc[X]=rc[X]=C;
val[X]=1;
}
inline void pshd(int X,int L,int R){
if(L==R){
return;
}
if(tg[X]){
mdf(LS,tg[X]),mdf(RS,tg[X]);
tg[X]=0;
}
}
inline int qryc(int X,int L,int R,int P){
if(L==R){
return tr[X];
}
pshd(X,L,R);
return P<=MID?qryc(LS,L,MID,P):qryc(RS,MID+1,R,P);
}
inline void updt(int X,int L,int R){
val[X]=(val[LS]+val[RS]-(rc[LS]==lc[RS]));
lc[X]=lc[LS],rc[X]=rc[RS];
}
inline void chg(int X,int L,int R,int A,int B,int C){
if(A<=L&&R<=B){
mdf(X,C);
return;
}
if(L>B||R<A){
return;
}
pshd(X,L,R);
chg(LS,L,MID,A,B,C);chg(RS,MID+1,R,A,B,C);
updt(X,L,R);
}
inline int qryv(int X,int L,int R,int A,int B){
if(A<=L&&R<=B){
return val[X];
}
if(L>B||R<A){
return 0;
}
pshd(X,L,R);
return qryv(LS,L,MID,A,B)+qryv(RS,MID+1,R,A,B)-((A<=MID&&B>MID)?(rc[LS]==lc[RS]):0);
}
void init(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i){
scanf("%d",&clr[i]);
}
int u,v;
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
add(u,v);
}
sz[0]=0,dep[0]=0;
dfs0(1,0);
dfs1(1,0,1);
for(int i=1;i<=n;++i){
chg(1,1,n,dfn[i],dfn[i],clr[i]);
}
char ch[4];
int c,ans;
for(int i=1;i<=m;++i){
std::cin>>ch+1;
switch(ch[1]){
case 'Q':{
scanf("%d%d",&u,&v);
ans=0;
while(tp[u]!=tp[v]){
if(dep[tp[u]]<dep[tp[v]]){
Swap(u,v);
}
ans+=qryv(1,1,n,dfn[tp[u]],dfn[u]);
ans-=(qryc(1,1,n,dfn[tp[u]])==qryc(1,1,n,dfn[fa[tp[u]]]));
u=fa[tp[u]];
}
if(dep[u]<dep[v]){
Swap(u,v);
}
ans+=qryv(1,1,n,dfn[v],dfn[u]);
printf("%d\n",ans);
break;
}
case 'C':{
scanf("%d%d%d",&u,&v,&c);
while(tp[u]!=tp[v]){
if(dep[tp[u]]<dep[tp[v]]){
Swap(u,v);
}
chg(1,1,n,dfn[tp[u]],dfn[u],c);
u=fa[tp[u]];
}
if(dep[u]<dep[v]){
Swap(u,v);
}
chg(1,1,n,dfn[v],dfn[u],c);
break;
}
}
}
}
int main(){
init();
return 0;
}