題目
這真是一道神仙的一批的題目
定義\(s(i,j)\)表示從點\(i\)到點\(j\)經過的顏色數量
設
\[sum_i=\sum_{j=1}^ns(i,j)\]
求出所有的\(sum_i\)
考慮點分治
對于一個點我們用兩種方式來統計其答案
這個點作為分治重心時,分值區域內所有點到這個點貢獻
這個點不是分治重心的時候,當前分治區域內其他子樹到這個點的貢獻
第一種貢獻我們很好統計,點分治的時候把所有子樹遍歷一遍就好了
第二種就需要轉換一下思路了,我們不能直接求\(s(i,j)\)了,我們應該求某一種顏色一共被數了多少次
我們開一個桶\(tax\),\(tax[i]\)表示\(i\)這種顏色控制的大小一共是多少,也就是這個顏色會被多少個終點數到,我們可以通過提前遍歷好所有子樹得到這個信息
每次進入一棵子樹的時候,提前減掉這個子樹的貢獻,之后進入子樹\(dfs\)就好了,如果一旦出現一種新顏色,顯然這種顏色會被當前分治區域內所有點數上,更改一下貢獻即可
代碼
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
#define maxn 100005
#define re register
#define inf 99999999
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {int x=0;char c=getchar();while(c<'0'||c>'9') c=getchar();while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
struct E{int v,nxt;}e[maxn<<1];
int col[maxn],head[maxn],vis[maxn],sum[maxn],mx[maxn];
int f[maxn],tax[maxn],d[maxn],st[maxn],tmp[maxn];
int num,n,m,now,S,rt,top;
LL ans,Ans[maxn],res;
std::vector<int> v[maxn],c[maxn];
inline void add(int x,int y) {e[++num].v=y;e[num].nxt=head[x];head[x]=num;}
void getroot(int x,int fa) {sum[x]=1,mx[x]=0;for(re int i=head[x];i;i=e[i].nxt) {if(vis[e[i].v]||e[i].v==fa) continue;getroot(e[i].v,x);sum[x]+=sum[e[i].v],mx[x]=max(mx[x],sum[e[i].v]);}mx[x]=max(mx[x],S-sum[x]);if(mx[x]<now) now=mx[x],rt=x;
}
void getdis(int x,int fa,int now,int t) {if(!f[col[x]]) now++;if(!tmp[col[x]]) st[++top]=col[x];tmp[col[x]]=1;sum[x]=1;f[col[x]]++;Ans[t]+=now;for(re int i=head[x];i;i=e[i].nxt) {if(vis[e[i].v]||e[i].v==fa) continue;getdis(e[i].v,x,now,t);sum[x]+=sum[e[i].v];}if(f[col[x]]==1) d[col[x]]+=sum[x];f[col[x]]--;
}
void find(int x,int fa) {if(!f[col[x]]) ans-=tax[col[x]],ans+=res;Ans[x]+=ans;f[col[x]]++;for(re int i=head[x];i;i=e[i].nxt) {if(vis[e[i].v]||e[i].v==fa) continue;find(e[i].v,x);} if(f[col[x]]==1) ans-=res,ans+=tax[col[x]];f[col[x]]--;
}
void dfs(int x) {vis[x]=1;ans=0;f[col[x]]=1;for(re int i=head[x];i;i=e[i].nxt) {if(vis[e[i].v]) continue;top=0;getdis(e[i].v,0,1,x);for(re int j=1;j<=top;j++) if(st[j]!=col[x]) v[e[i].v].push_back(d[st[j]]),c[e[i].v].push_back(st[j]);for(re int j=1;j<=top;j++) if(st[j]!=col[x]) tax[st[j]]+=d[st[j]],ans+=d[st[j]];for(re int j=1;j<=top;j++) tmp[st[j]]=0,d[st[j]]=0;}f[col[x]]=0;ans+=S,tax[col[x]]=S;for(re int i=head[x];i;i=e[i].nxt) {if(vis[e[i].v]) continue;res=S-sum[e[i].v];ans-=sum[e[i].v],tax[col[x]]-=sum[e[i].v];for(re int j=0;j<v[e[i].v].size();j++) ans-=v[e[i].v][j],tax[c[e[i].v][j]]-=v[e[i].v][j];find(e[i].v,0);for(re int j=0;j<v[e[i].v].size();j++) ans+=v[e[i].v][j],tax[c[e[i].v][j]]+=v[e[i].v][j];ans+=sum[e[i].v],tax[col[x]]+=sum[e[i].v];}for(re int i=head[x];i;i=e[i].nxt) {if(vis[e[i].v]) continue;for(re int j=0;j<v[e[i].v].size();j++)tax[c[e[i].v][j]]-=v[e[i].v][j];v[e[i].v].clear(),c[e[i].v].clear();}tax[col[x]]=0;for(re int i=head[x];i;i=e[i].nxt) {if(vis[e[i].v]) continue;now=inf,S=sum[e[i].v],getroot(e[i].v,0),dfs(rt);}
}
int main() {n=read();int x,y;for(re int i=1;i<=n;i++) col[i]=read();for(re int i=1;i<n;i++) x=read(),y=read(),add(x,y),add(y,x);S=n,now=inf,getroot(1,0);dfs(rt);for(re int i=1;i<=n;i++) printf("%lld\n",Ans[i]+1ll);return 0;
}