一种\(n^2log^2n\)的做法是可以想到的:点分治,使用multiset来维护以某个节点为根的某一些链上的颜色信息。合并的时候\(O(nlogn)\)合并。
我们略微计算了一下复杂度,感觉这是卡不过去的。
仔细想想我们发现计算一条路径上不同的颜色数量是较为困难的,故而我们可以转而统计不同颜色对\(ans_i\)的贡献。
我们仔细思考,发现,对于以一个点\(rt\)为根的子树的某个点\(i\),其颜色是从\(i\)到\(rt\)的路径上首次出现的,那么我们在计算任意一个其他子树内的点\(j\)答案时,可以将答案减去\(sz_i\),然后再容斥掉\(j\)到\(rt\)路径上存在\(a_i\)这种颜色的情况。
故而,对于每一次分治要处理的以\(rt\)为根的树,我们可以先预处理出\(sm\),表示的是这棵树中所有满足「到根节点路径上没有与其颜色相同的点」的点的子树大小之和;以及\(val_i\),表示的是颜色为\(i\)的到根节点路上没有颜色为\(i\)的节点的子树大小之和。然后我们需要对以它的每一个子节点为根的子树统计答案。
显而易见的,\(j\)到\(rt\)的路径上经过的颜色是不应该被重复统计的,所以我们要让\(sm\)减去\(\sum val_{i}\),其中\(i\)是路径上经过的颜色。
此外,我们还需要统计这些颜色对答案的贡献。令这些颜色的数量为\(nm\),那么它们对\(j\)的答案的贡献就应该是\(nm*dlt\),其中\(dlt\)是\(rt\)除了当前正在找的子节点以外的子节点的子树大小之和。
然后,我们还需要统计以根节点为路径端点的答案数量——这种答案是上述方法未统计到的。所以,我们需要将根节点的答案减去\(val_{a_{rt}}\),然后再加上以根节点为根的所有路径数量,也就是\(sz_{rt}\)
#include<iostream>
#include<cstdio>
#include<set>
#define Fv(i,X) for(int i=h[X];i;i=e[i].nxt)
typedef long long ll;
inline int Max(int A,int B){
return A>B?A:B;
}
inline ll Max(ll A,ll B){
return A>B?A:B;
}
struct ee{
int v;
int nxt;
}e[200005];
int h[100005],et=0;
inline void Eadd(int U,int V){
e[++et]=(ee){V,h[U]};
h[U]=et;
}
inline void add(int U,int V){
Eadd(U,V);
Eadd(V,U);
}
int n,a[100005];
int s,rt=0;
int vis[100005];
ll sz[100005],mx[100005],ans[100005];
inline void dfs0(int X,int FA){
sz[X]=1,mx[X]=0;
Fv(i,X){
if(e[i].v==FA||vis[e[i].v]){
continue;
}
dfs0(e[i].v,X);
sz[X]+=sz[e[i].v];
mx[X]=Max(mx[X],sz[e[i].v]);
}
mx[X]=Max(mx[X],s-sz[X]);
if(mx[X]<mx[rt]){
rt=X;
}
}
ll cnt[100005],val[100005],sm=0,nm=0,dlt;
inline void dfs2(int X,int FA){
sz[X]=1;
++cnt[a[X]];
Fv(i,X){
if(e[i].v==FA||vis[e[i].v]){
continue;
}
dfs2(e[i].v,X);
sz[X]+=sz[e[i].v];
}
if(cnt[a[X]]==1){
sm+=sz[X];
val[a[X]]+=sz[X];
}
--cnt[a[X]];
}
inline void dfs3(int X,int FA,int TYP){
++cnt[a[X]];
Fv(i,X){
if(e[i].v==FA||vis[e[i].v]){
continue;
}
dfs3(e[i].v,X,TYP);
}
if(cnt[a[X]]==1){
sm+=sz[X]*TYP;
val[a[X]]+=sz[X]*TYP;
}
--cnt[a[X]];
}
inline void dfs4(int X,int FA){
++cnt[a[X]];
if(cnt[a[X]]==1){
++nm;
sm-=val[a[X]];
}
ans[X]+=sm+nm*dlt;
Fv(i,X){
if(e[i].v==FA||vis[e[i].v]){
continue;
}
dfs4(e[i].v,X);
}
if(cnt[a[X]]==1){
--nm;
sm+=val[a[X]];
}
--cnt[a[X]];
}
inline void dfs5(int X,int FA){
val[a[X]]=cnt[a[X]]=0;
Fv(i,X){
if(e[i].v==FA||vis[e[i].v]){
continue;
}
dfs5(e[i].v,X);
}
}
inline void calc(int X){
sm=nm=0;
dfs2(X,0);
ans[rt]+=sm-val[a[X]]+sz[X];
Fv(i,X){
if(vis[e[i].v]){
continue;
}
++cnt[a[X]];
sm-=sz[e[i].v];
val[a[X]]-=sz[e[i].v];
dfs3(e[i].v,X,-1);
--cnt[a[X]];
dlt=sz[X]-sz[e[i].v];
dfs4(e[i].v,X);
++cnt[a[X]];
sm+=sz[e[i].v];
val[a[X]]+=sz[e[i].v];
dfs3(e[i].v,X,1);
--cnt[a[X]];
}
dfs5(X,0);
}
inline void dfs1(int X){
vis[X]=1;
calc(X);
Fv(i,X){
if(vis[e[i].v]){
continue;
}
s=sz[X],rt=0;
dfs0(e[i].v,X);
dfs1(rt);
}
}
void init(){
scanf("%d",&n);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
}
int u,v;
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
add(u,v);
}
rt=0,mx[0]=n;
dfs0(1,0);
dfs1(rt);
for(int i=1;i<=n;++i){
printf("%lld\n",ans[i]);
}
}
int main(){
init();
return 0;
}