今天學習了樹鏈剖分,記錄一下。
【題目背景】
HYSBZ - 1036樹的統計Count
【題目分析】
題目要求求任意結點之間路徑的和以及路徑上最大的結點,還有可能修改。如果正常做可能會很復雜(我也不知道正常應該怎么做,應該要用到LCA什么的,我還不太會)。
但是如果我們能夠用線段樹或者樹狀數組維護這個樹,那么這種問題就會變得很簡單。樹鏈剖分就是這樣一種將樹映射在一個數組上變成線性結構然后用線段樹進行維護的數據結構。
【基礎知識】
- 重兒子:兒子中子樹結點數目最多的那個兒子(size最大)
- 重邊:父親結點和重兒子連成的邊
- 重鏈:由多條重邊連接而成的路徑
- 輕兒子:除了重兒子的其他兒子
- 輕邊:父親和輕兒子連成的邊
如圖所示,紅圈的表示重兒子,黑邊表示重邊。由黑邊組成的鏈為重鏈。
【具體實現】
我們先進行一次遍歷得到重兒子以及深度等信息儲存起來
void dfs1(int u,int f)
{int i,v;siz[u]=1; //儲存該結點子樹的大小(最小只有自身一個結點)son[u]=0; //儲存重兒子fa[u]=f; //儲存父節點h[u]=h[f]+1;//儲存深度for(i=0;i<g[u].size();i++){v=g[u][i];if(v!=f){dfs1(v,u); //深度優先遍歷siz[u]+=siz[v];if(siz[son[u]]<siz[v]) son[u]=v;}}
}
得到以上數據后,我們可以按重鏈將樹映射在一個數組上。從根節點開始,優先將重鏈映射到數組上,然后按照深度依次進行輕兒子,輕兒子又是某一個重鏈的開始(每一個節點都處于一個且僅有一個重鏈中)。記錄每條每個節點所屬重鏈的開頭(從而判斷兩個節點是否在同一個重鏈上)。
void dfs2(int u,int f,int k)
{int i,v;top[u]=k; //記錄所屬重鏈的開頭pos[u]=++cnt;//映射到數組上的下標(同一個重鏈的下標是連續的)A[cnt]=val[u];//確定數組所對應節點的值方便進行維護if(son[u]) dfs2(son[u],u,k);//優先遍歷重兒子,從而得到連續的重鏈for(i=0;i<g[u].size();i++){v=g[u][i]; if(v!=f&&v!=son[u]) dfs2(v,u,v); //遍歷其他輕兒子}
}
成功將樹映射到數組上以后我們再用線段樹對數組進行維護。對于線段樹的維護是常規操作。
void update(int k,int l,int r,int x,int v)
{if(l==r){Sum[k]=Max[k]=v;return;}int mid=(l+r)/2;if(x<=mid) update(k<<1,l,mid,x,v);else update(k<<1|1,mid+1,r,x,v);Sum[k]=Sum[k<<1]+Sum[k<<1|1];Max[k]=max(Max[k<<1],Max[k<<1|1]);
}int QuerySum(int k,int l,int r,int L,int R)
{if(L<=l && r<=R) return Sum[k];int mid=(l+r)/2;int ret=0;if(L<=mid) ret+=QuerySum(k<<1,l,mid,L,R);if(R>mid) ret+=QuerySum(k<<1|1,mid+1,r,L,R);return ret;
}int QueryMax(int k,int l,int r,int L,int R)
{if(L==l && r==R) return Max[k];int mid=(l+r)/2;if(R<=mid) return QueryMax(k<<1,l,mid,L,R);else if(L>mid) return QueryMax(k<<1|1,mid+1,r,L,R);else return max(QueryMax(k<<1,l,mid,L,mid),QueryMax(k<<1|1,mid+1,r,mid+1,R));
}
重點還是對于樹上兩個點如何得到他們之間的一條路徑以及這個路徑在映射數組中的位置。我們每次從深度更深的點向上升,直到兩個節點處在同一條鏈中(或者處于同一節點處)。在上升的過程中記錄每條鏈的值(每條鏈都處于映射數組的一個連續的區間內)
int FindSum(int u,int v)
{int ans=0;while(top[u]!=top[v]){if(h[top[u]]<h[top[v]]) swap(u,v);ans+=QuerySum(1,1,n,pos[top[u]],pos[u]);u=fa[top[u]];}if(h[u]<h[v]) swap(u,v);ans+=QuerySum(1,1,n,pos[v],pos[u]);return ans;
}int FindMax(int u,int v)
{int ans=INT_MIN;while(top[u]!=top[v]){if(h[top[u]]<h[top[v]]) swap(u,v);ans=max(ans,QueryMax(1,1,n,pos[top[u]],pos[u]));u=fa[top[u]];}if(h[u]<h[v]) swap(u,v);ans=max(ans,QueryMax(1,1,n,pos[v],pos[u]));return ans;
}
這樣我們就成功做到用線段樹維護樹狀結構的數據啦
【AC代碼】
#include<iostream>
#include<cstdio>
#include<vector>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<climits>using namespace std;const int MAXN=30010;
vector<int>g[MAXN];
int fa[MAXN],A[MAXN],val[MAXN],pos[MAXN],siz[MAXN],son[MAXN],h[MAXN],top[MAXN];
int cnt=0,n,m;
int Sum[MAXN<<2],Max[MAXN<<2];void dfs1(int u,int f)
{int i,v;siz[u]=1;son[u]=0;fa[u]=f;h[u]=h[f]+1;for(i=0;i<g[u].size();i++){v=g[u][i];if(v!=f){dfs1(v,u);siz[u]+=siz[v];if(siz[son[u]]<siz[v]) son[u]=v;}}
}
void dfs2(int u,int f,int k)
{int i,v;top[u]=k;pos[u]=++cnt;A[cnt]=val[u];if(son[u]) dfs2(son[u],u,k);for(i=0;i<g[u].size();i++){v=g[u][i];if(v!=f&&v!=son[u]) dfs2(v,u,v);}
}void update(int k,int l,int r,int x,int v)
{if(l==r){Sum[k]=Max[k]=v;return;}int mid=(l+r)/2;if(x<=mid) update(k<<1,l,mid,x,v);else update(k<<1|1,mid+1,r,x,v);Sum[k]=Sum[k<<1]+Sum[k<<1|1];Max[k]=max(Max[k<<1],Max[k<<1|1]);
}int QuerySum(int k,int l,int r,int L,int R)
{if(L<=l && r<=R) return Sum[k];int mid=(l+r)/2;int ret=0;if(L<=mid) ret+=QuerySum(k<<1,l,mid,L,R);if(R>mid) ret+=QuerySum(k<<1|1,mid+1,r,L,R);return ret;
}int QueryMax(int k,int l,int r,int L,int R)
{if(L==l && r==R) return Max[k];int mid=(l+r)/2;if(R<=mid) return QueryMax(k<<1,l,mid,L,R);else if(L>mid) return QueryMax(k<<1|1,mid+1,r,L,R);else return max(QueryMax(k<<1,l,mid,L,mid),QueryMax(k<<1|1,mid+1,r,mid+1,R));
}int FindSum(int u,int v)
{int ans=0;while(top[u]!=top[v]){if(h[top[u]]<h[top[v]]) swap(u,v);ans+=QuerySum(1,1,n,pos[top[u]],pos[u]);u=fa[top[u]];}if(h[u]<h[v]) swap(u,v);ans+=QuerySum(1,1,n,pos[v],pos[u]);return ans;
}int FindMax(int u,int v)
{int ans=INT_MIN;while(top[u]!=top[v]){if(h[top[u]]<h[top[v]]) swap(u,v);ans=max(ans,QueryMax(1,1,n,pos[top[u]],pos[u]));u=fa[top[u]];}if(h[u]<h[v]) swap(u,v);ans=max(ans,QueryMax(1,1,n,pos[v],pos[u]));return ans;
}int main()
{int a,b,i;char s[10];scanf("%d",&n);for(i=1;i<n;i++){scanf("%d%d",&a,&b);g[a].push_back(b);g[b].push_back(a);}for(i=1;i<=n;i++) scanf("%d",&val[i]);dfs1(1,0);dfs2(1,0,1);for(i=1;i<=n;i++) update(1,1,n,i,A[i]);scanf("%d",&m);while(m--){scanf("%s%d%d",s,&a,&b);if(s[1]=='H') update(1,1,n,pos[a],b);else if(s[1]=='S') printf("%d\n",FindSum(a,b));else printf("%d\n",FindMax(a,b));}return 0;
}