傳送門
先考慮一個貪心,對于一條邊來說,如果當前這個序列中在它的子樹中的元素個數為奇數個,那么這條邊就會被一組匹配經過,否則就不會
考慮反證法,如果在這條邊兩邊的元素個數都是偶數,那么至少有兩組匹配經過它,那么把這兩條路徑都刪去這條邊可以更優。如果兩邊是奇數,一定至少有一條路徑經過它,去掉這組匹配之后就變成了偶數的情況。證畢
然后是一個神仙的轉化,我們對于一顆子樹中的元素,在序列里標記為\(1\),否則為\(0\),那么這條邊出現次數就是序列中長度為偶數且區間和為奇數的區間個數
考慮用線段樹合并優化,對于每個節點,記\(t[p][0/1][0/1]\)表示節點\(p\)代表的區間中前綴和為偶數\(/\)奇數,下標為偶數\(/\)奇數的下標個數,然后線段樹合并就行了
然而咱還是搞不明白為啥線段樹上的區間要設為\([1,m+1]\)……有哪位知道為什么的請告訴咱一聲……
//minamoto
#include<bits/stdc++.h>
#define R register
#define fp(i,a,b) for(R int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(R int i=a,I=b-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
int read(){R int res,f=1;R char ch;while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');return res*f;
}
const int N=1e5+5,M=N<<5,P=998244353;
inline int add(R int x,R int y){return x+y>=P?x+y-P:x+y;}
inline int dec(R int x,R int y){return x-y<0?x-y+P:x-y;}
inline int mul(R int x,R int y){return 1ll*x*y-1ll*x*y/P*P;}
int ksm(R int x,R int y){R int res=1;for(;y;y>>=1,x=mul(x,x))if(y&1)res=mul(res,x);return res;
}
struct eg{int v,nx,w;}e[N<<1];int head[N],tot;
inline void add_edge(R int u,R int v,R int w){e[++tot]={v,head[u],w},head[u]=tot;}
int sum[M],ls[M],rs[M],t[M][2][2],rt[N];
int n,m,ans,cnt,u,v,w;
void upd(int p,int l,int r){sum[p]=0;if(ls[p])sum[p]+=sum[ls[p]];if(rs[p])sum[p]+=sum[rs[p]];int x=ls[p]?sum[ls[p]]&1:0;fp(i,0,1)fp(j,0,1){t[p][i][j]=0;if(ls[p])t[p][i][j]+=t[ls[p]][i][j];if(rs[p])t[p][i][j]+=t[rs[p]][i^x][j];}int mid=(l+r)>>1;if(!ls[p])t[p][0][0]+=(mid>>1)-((l-1)>>1),t[p][0][1]+=((mid+1)>>1)-(l>>1);if(!rs[p])t[p][x][0]+=(r>>1)-(mid>>1),t[p][x][1]+=((r+1)>>1)-((mid+1)>>1);
}
void ins(int &p,int l,int r,int x){if(!p){p=++cnt;t[p][0][0]=(r>>1)-((l-1)>>1);t[p][0][1]=((r+1)>>1)-(l>>1);}if(l==r)return ++sum[p],void();int mid=(l+r)>>1;x<=mid?ins(ls[p],l,mid,x):ins(rs[p],mid+1,r,x);upd(p,l,r);
}
int merge(int x,int y,int l,int r){if(!x||!y)return x|y;int mid=(l+r)>>1;ls[x]=merge(ls[x],ls[y],l,mid);rs[x]=merge(rs[x],rs[y],mid+1,r);upd(x,l,r);return x;
}
void dfs(int u,int fa){go(u)if(v!=fa){dfs(v,u);ans=add(ans,mul(e[i].w,1ll*t[rt[v]][0][0]*t[rt[v]][1][0]%P+1ll*t[rt[v]][0][1]*t[rt[v]][1][1]%P));rt[u]=merge(rt[u],rt[v],1,m+1);}
}
int main(){
// freopen("testdata.in","r",stdin);n=read(),m=read();fp(i,1,n-1)u=read(),v=read(),w=read(),add_edge(u,v,w),add_edge(v,u,w);fp(i,1,m)u=read(),ins(rt[u],1,m+1,i);dfs(1,0);printf("%d\n",ans);return 0;
}