題目描述
給定一棵樹,設計數據結構支持以下操作 1 u v d 表示將路徑 (u,v) 加d 2 u v 表示詢問路徑 (u,v) 上點權絕對值的和
輸入
第一行兩個整數n和m,表示結點個數和操作數
接下來一行n個整數a_i,表示點i的權值接下來n-1行,每行兩個整數u,v表示存在一條(u,v)的邊接下來m行,每行一個操作,輸入格式見題目描述
輸出
對于每個詢問輸出答案
樣例輸入
4 4
-4 1 5 -2
1 2
2 3
3 4
2 1 3
1 1 4 3
2 1 3
2 3 4
-4 1 5 -2
1 2
2 3
3 4
2 1 3
1 1 4 3
2 1 3
2 3 4
樣例輸出
10
13
9
13
9
提示
對于100%的數據,n,m <= 10^5 且 0<= d,|a_i|<= 10^8
?
如果都是正數直接樹鏈剖分+線段樹就行了。
現在有了負數,那不是再維護一個區間正數個數就好了?顯然是不夠的。
因為區間修改時會把一些負數變為正數,會改變區間正數的個數,所以我們要維護區間三個值:
1、區間絕對值之和
2、區間非負數個數
3、區間最大的負數
當每次修改一個區間時如果這個區間的最大負數會變成非負數,那么說明這個區間的非負數個數會改變,因此要重構這個區間。
怎么重構呢?
對于這個區間的左右子區間,對于不需要重構的子區間下傳標記,對于需要重構的子區間就遞歸重構下去。
因為每個數最多只會被重構一次,因此重構均攤O(nlogn)。總時間復雜度還是O(mlogn)級別。
#include<set>
#include<map>
#include<stack>
#include<queue>
#include<cmath>
#include<cstdio>
#include<vector>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
int num[800010];
int mx[800010];
ll sum[800010];
int d[100010];
int f[100010];
int son[100010];
int size[100010];
int top[100010];
int to[200010];
int tot;
int head[100010];
int s[100010];
int q[100010];
int n,m;
int x,y,z;
int opt;
int cnt;
ll a[800010];
int next[200010];
int v[100010];
int merge(int x,int y)
{if(x<0&&y<0){return max(x,y);}if(x<0){return x;}if(y<0){return y;}return 0;
}
void add(int x,int y)
{tot++;next[tot]=head[x];head[x]=tot;to[tot]=y;
}
void dfs(int x)
{size[x]=1;d[x]=d[f[x]]+1;for(int i=head[x];i;i=next[i]){if(to[i]!=f[x]){f[to[i]]=x;dfs(to[i]);size[x]+=size[to[i]];if(size[to[i]]>size[son[x]]){son[x]=to[i];}}}
}
void dfs2(int x,int tp)
{s[x]=++cnt;top[x]=tp;q[cnt]=x;if(son[x]){dfs2(son[x],tp);}for(int i=head[x];i;i=next[i]){if(to[i]!=f[x]&&to[i]!=son[x]){dfs2(to[i],to[i]);}}
}
void pushup(int rt)
{num[rt]=num[rt<<1]+num[rt<<1|1];sum[rt]=sum[rt<<1]+sum[rt<<1|1];mx[rt]=merge(mx[rt<<1],mx[rt<<1|1]);
}
void pushdown(int rt,bool x,bool y,int l,int r)
{if(a[rt]){int mid=(l+r)>>1;if(x){if(mx[rt<<1]){mx[rt<<1]+=a[rt];}sum[rt<<1]+=1ll*(2*num[rt<<1]-(mid-l+1))*a[rt];a[rt<<1]+=a[rt];}if(y){if(mx[rt<<1|1]){mx[rt<<1|1]+=a[rt];}sum[rt<<1|1]+=1ll*(2*num[rt<<1|1]-(r-mid))*a[rt];a[rt<<1|1]+=a[rt];}a[rt]=0;}
}
void build(int rt,int l,int r)
{if(l==r){if(v[q[l]]<0){mx[rt]=v[q[l]];}else{num[rt]=1;}sum[rt]=abs(v[q[l]]);return ;}int mid=(l+r)>>1;build(rt<<1,l,mid);build(rt<<1|1,mid+1,r);pushup(rt);
}
void rebuild(int rt,int l,int r,ll c)
{if(l==r){num[rt]=1;sum[rt]=mx[rt]+c;mx[rt]=0;return ;}int mid=(l+r)>>1;c+=a[rt];a[rt]=c;if(mx[rt<<1]&&mx[rt<<1]+c>=0&&mx[rt<<1|1]&&mx[rt<<1|1]+c>=0){a[rt]=0;rebuild(rt<<1,l,mid,c);rebuild(rt<<1|1,mid+1,r,c);}else if(mx[rt<<1]&&mx[rt<<1]+c>=0){pushdown(rt,0,1,l,r);rebuild(rt<<1,l,mid,c);}else if(mx[rt<<1|1]&&mx[rt<<1|1]+c>=0){pushdown(rt,1,0,l,r);rebuild(rt<<1|1,mid+1,r,c);}pushup(rt);
}
void change(int rt,int l,int r,int L,int R,int c)
{if(L<=l&&r<=R){if(mx[rt]+c>=0&&mx[rt]){rebuild(rt,l,r,c);}else{if(mx[rt]){mx[rt]+=c;}a[rt]+=c;sum[rt]+=1ll*(2*num[rt]-(r-l+1))*c;}return ;}int mid=(l+r)>>1;pushdown(rt,1,1,l,r);if(L<=mid){change(rt<<1,l,mid,L,R,c);}if(R>mid){change(rt<<1|1,mid+1,r,L,R,c);}pushup(rt);
}
ll query(int rt,int l,int r,int L,int R)
{if(L<=l&&r<=R){return sum[rt];}pushdown(rt,1,1,l,r);int mid=(l+r)>>1;long long res=0;if(L<=mid){res+=query(rt<<1,l,mid,L,R);}if(R>mid){res+=query(rt<<1|1,mid+1,r,L,R);}return res;
}
void updata(int x,int y,int z)
{while(top[x]!=top[y]){if(d[top[x]]<d[top[y]]){swap(x,y);}change(1,1,n,s[top[x]],s[x],z);x=f[top[x]];}if(d[x]>d[y]){swap(x,y);}change(1,1,n,s[x],s[y],z);
}
ll downdata(int x,int y)
{ll res=0;while(top[x]!=top[y]){if(d[top[x]]<d[top[y]]){swap(x,y);}res+=query(1,1,n,s[top[x]],s[x]);x=f[top[x]];}if(d[x]>d[y]){swap(x,y);}res+=query(1,1,n,s[x],s[y]);return res;
}
int main()
{scanf("%d%d",&n,&m);for(int i=1;i<=n;i++){scanf("%d",&v[i]);}for(int i=1;i<n;i++){scanf("%d%d",&x,&y);add(x,y);add(y,x);}dfs(1);dfs2(1,1);build(1,1,n);while(m--){scanf("%d",&opt);scanf("%d%d",&x,&y);if(opt==1){scanf("%d",&z);updata(x,y,z);}else{printf("%lld\n",downdata(x,y));}}
}