题意:有一棵树,每一个节点都有一个值val[i]表示这个点有多少只晶蝶,t[i]表示假设你到达该点连接的点之后,经过多少秒该点的晶蝶都会跑光(从一个点到另一个点花费时间为1).问最多可以抓多少晶蝶
f[i][0]表示所有的子节点的值得不到的时候的子树答案最大值
f[i][1]表示子树答案(取当前点)最大值
f[x][0]的计算方法很简单,直接Σ(f[j][1]-val[j])即可,意为我们取
这个点的子树的答案的最大值,但是不取当前的这个点.
我们分情况讨论,分为两种情况,一种是我们在当前节点的下一层只得到
一个点的值,另一种是当t[j]==3时我们先去的到另外一个点的权值再拐
弯回到这个j点往下走的情况.
下面分类讨论:
1.当前父节点x的下一层子节点只获得一个点的情况,这个结果表示为
f[x][1]=max(f[x][0]+val[y],f[x][1]),意为假设该节点的子节
点都不取,然后取一个点的值的情况,这种情况其实包含了几种小情况,
(1)获取当前点值直接往下走,(2)取了同层的一个点之后在拐弯过来
不获取当前点的值往下走.但是由于是对于情况取max所以两种小情况
都包含了.
2.x的子节点j的t[j]==3,我们先去拐弯取一个点在回来按照j点往下
走,和上面的拐弯一样要注意,拐弯处那个点的子节点的值我们都得不
到,就可以按照上面的f[i][0]算.所以我们这里的状态转移方程式就
是当t[j]==3
f[x][1]=max(f[x][0]+val[j]+f[y][0]-f[y][1]+val[y],f[x][1])
这里可以知道f[x][0]+val[j]是固定的值,那么只需要求
temp=f[y][0]-f[y][1]+val[y]的最大值(j表示拐弯后往下那个点,y表
示拐弯处的点),知道temp的最大值时候还要注意假设temp的这个t[y]==3
的情况,我们就要维护一个最大值和一个次大值去求最值
- #include<bits/stdc++.h>
- #define ll long long
- using namespace std;
- const int N =1e5+10,mod=998244353;
- int val[N],t[N],h[N],tot=0;
- ll f[N][2];
- struct node
- {
- int to,ne;
- }edge[2*N];
- void add(int x,int y)
- {
- edge[++tot].to=y;
- edge[tot].ne=h[x];
- h[x]=tot;
- }
- void dfs(int x,int fa)
- {
- f[x][1]=0;
- f[x][0]=val[x];
- int max1=0,max2=0;
- for(int i=h[x];i!=-1;i=edge[i].ne)
- {
- int j=edge[i].to;
- if(j==fa)
- continue;
- dfs(j,x);
- f[x][0]+=(f[j][1]-val[j]);
- int temp=f[j][0]-f[j][1]+val[j];
- if(temp>=max1)
- {
- max2=max1;
- max1=temp;
- }
- else if(temp>=max2)
- max2=temp;
- }
- f[x][1]=f[x][0];
- for(int i=h[x];i!=-1;i=edge[i].ne)
- {
- int j=edge[i].to;
- if(j==fa)
- continue;
- f[x][1]=max(f[x][0]+val[j],f[x][1]);
- if(t[j]==3)
- {
- if(f[j][0]-f[j][1]+val[j]==max1)
- {
- f[x][1]=max(f[x][0]+val[j]+max2,f[x][1]);
- }
- else
- {
- f[x][1]=max(f[x][0]+val[j]+max1,f[x][1]);
- }
- }
-
- }
- }
- void solve()
- {
- memset(h,-1,sizeof h);
- tot=0;
- int n,x,y;
- cin>>n;
- for(int i=1;i<=n;i++)
- cin>>val[i];
- for(int i=1;i<=n;i++)
- cin>>t[i];
- for(int i=1;i<n;i++)
- {
- cin>>x>>y;
- add(x,y);
- add(y,x);
- }
- dfs(1,0);
- cout<<f[1][1]<<endl;
- }
- signed main()
- {
- cin.tie(0);
- cout.tie(0);
- ios::sync_with_stdio(0);
- int t;
- cin>>t;
- while(t--)
- solve();
- return 0;
- }