NIO is playing a game about trees.
The game has two trees $A$ and $B$, each with $N$ vertices. The vertices in each tree are numbered from $1$ to $N$, and the $i$-th vertex has the weight $v_i$. The root of each tree is vertex $1$. Given $K$ key numbers $x_1,\dots,x_K$, find the number of solutions that remove exactly one number so that the weight of the lowest common ancestor of the vertices in $A$ with the remaining key numbers is greater than the weight of the lowest common ancestor of the vertices in $B$ with the remaining key numbers.
The first line has two positive integers $N, K$ ($2 \leq K \leq N \leq 10^5$).
The second line has $K$ unique positive integers $x_1,\dots,x_K$ ($x_i \leq N$).
The third line has $N$ positive integers $a_i$ ($a_i \leq 10^9$), representing the weights of the vertices in $A$.
The fourth line has $N-1$ positive integers $\{pa_i\}$, indicating that the father of vertex $i+1$ in tree $A$ is $pa_i$.
The fifth line has $N$ positive integers $b_i$ ($b_i \leq 10^9$), representing the weights of the vertices in $B$.
The sixth line has $N-1$ positive integers $\{pb_i\}$, indicating that the father of vertex $i+1$ in tree $B$ is $pb_i$.
One integer indicating the answer.
示例1
5 3 5 4 3 6 6 3 4 6 1 2 2 4 7 4 5 7 7 1 1 3 2
1
In first case, the key numbers are 5,4,3. Remove key number 5, the lowest common ancestors of the vertices in A with the remaining key numbers is 2, in B is 3. Remove key number 4, the lowest common ancestors of the vertices in A with the remaining key numbers is 2, in B is 1. Remove key number 3, the lowest common ancestors of the vertices in A with the remaining key numbers is 4, in B is 1. Only remove key number 5 satisfies the requirement.
示例2
10 3 10 9 8 8 9 9 2 7 9 0 0 7 4 1 1 2 4 3 4 2 4 7 7 7 2 3 4 5 6 1 5 3 1 1 3 1 2 4 7 3 5
2
题意: 给出两棵树A和B,点的编号都是从1~n,每棵树上的点都有各自的价值wi,再给出一个包含k个点的关键点集合,现在可以移除关键点集合中的一个点,使得剩下关键点在A树上的LCA的点权大于在B树上的LCA的点权,求多少种方案符合要求。
分析: 比赛的时候想到的是比较麻烦的模拟做法,根据树的形态来分类讨论,其实这样是非常麻烦的,赛后看到了正解,正解是维护前缀LCA和后缀LCA,这样就能很快求出剩下点的LCA了,时间复杂度为O(nlogn)。具体做法就是先建好两棵树,然后对于A树维护一个关键点的前缀LCA数组front[i][1],再维护一个后缀LCA数组back[i][1],对于B树同理得到front[i][2],back[i][2],然后就可以枚举移除哪个关键点了,剩下点在A树中的LCA就是lca(front[i-1][1], back[i+1][1]),在B树中的LCA就是lca(front[i-1][2], back[i+1][2]),注意边界的特判。
具体代码如下:
正解:
- #include
- #include
- #include
- #include
- #include
- #include
- using namespace std;
-
- vector<int> g[100005][3]; // g[u][t]: children of vertex u in tree t (main uses t = 1 for tree A, t = 2 for tree B)
- int n, k, w[100005][3], fa[100005][21][3], dep[100005][3], a[100005]; // n vertices, k key vertices a[1..k]; w[v][t] weight, fa[v][j][t] 2^j-th ancestor, dep[v][t] depth (all per tree t)
- int front[100005][3], back[100005][3];// prefix LCA and suffix LCA of the key vertices, per tree
-
- void dfs(int now, int pre, int type)
- {
- dep[now][type] = dep[pre][type]+1;
- fa[now][0][type] = pre;
- for(int i = 1; i <= 20; i++)//超出范围的祖先都是0号结点
- fa[now][i][type] = fa[fa[now][i-1][type]][i-1][type];
- for(int i = 0; i < g[now][type].size(); i++)
- if(g[now][type][i] != pre)
- dfs(g[now][type][i], now, type);
- }
-
- int lca(int x, int y, int type)
- {
- if(dep[x][type] < dep[y][type]) swap(x, y);
- for(int i = 20; i >= 0; i--)
- if(dep[fa[x][i][type]][type] >= dep[y][type])
- x = fa[x][i][type];
- if(x == y) return x;
- for(int i = 20; i >= 0; i--)
- if(fa[x][i][type] != fa[y][i][type])
- x = fa[x][i][type], y = fa[y][i][type];
- return fa[x][0][type];
- }
-
- signed main()
- {
- cin >> n >> k;
- for(int i = 1; i <= k; i++) scanf("%d", &a[i]);
- for(int i = 1; i <= n; i++) scanf("%d", &w[i][1]);
- for(int i = 2; i <= n; i++){
- int u;
- scanf("%d", &u);
- g[u][1].push_back(i);
- }
- for(int i = 1; i <= n; i++) scanf("%d", &w[i][2]);
- for(int i = 2; i <= n; i++){
- int u;
- scanf("%d", &u);
- g[u][2].push_back(i);
- }
- dfs(1, 0, 1);
- front[0][1] = a[1];
- back[k+1][1] = a[k];
- for(int i = 1; i <= k; i++) front[i][1] = lca(front[i-1][1], a[i], 1);
- for(int i = k; i >= 1; i--) back[i][1] = lca(back[i+1][1], a[i], 1);
- dfs(1, 0, 2);
- front[0][2] = a[1];
- back[k+1][2] = a[k];
- for(int i = 1; i <= k; i++) front[i][2] = lca(front[i-1][2], a[i], 2);
- for(int i = k; i >= 1; i--) back[i][2] = lca(back[i+1][2], a[i], 2);
- int ans = 0;
- for(int i = 1; i <= k; i++){//枚举移除哪个点
- int top1, top2;
- if(i < k && i > 1){
- top1 = lca(front[i-1][1], back[i+1][1], 1);
- top2 = lca(front[i-1][2], back[i+1][2], 2);
- }
- else if(i == 1){
- top1 = back[2][1];
- top2 = back[2][2];
- }
- else{
- top1 = front[k-1][1];
- top2 = front[k-1][2];
- }
- if(w[top1][1] > w[top2][2]) ans++;
- }
- printf("%d\n", ans);
- return 0;
- }
-
-
比赛时的模拟做法:
- #include
- #include
- #include
- #include
- #include
- #include
- #include
- #define inf 0x3f3f3f3f
- using namespace std;
-
- int a[100005], w[100005][3]; // a[1..k]: key vertices; w[v][t]: weight of vertex v in tree t (t = 1 or 2)
- int num[100005][3]; // num[v][t]: key vertices counted in v's subtree of tree t, accumulated by dfs2
- bool flag[100005]; // flag[v]: v is one of the key vertices
- int dep[100005][3], fa[100005][21][3]; // per-tree depth and 2^j-th ancestor tables filled by dfs
- int F[100005][3]; // F[v][t]: super_fa recorded by dfs2 for key vertex v (main passes the child of the overall LCA whose subtree holds v)
- vector<int> tr[100005][3]; // tr[u][t]: children of vertex u in tree t
- int n, k; // vertex count and number of key vertices
- vector<int> temp; // scratch list of vertices
-
- void dfs(int now, int pre, int type){
- dep[now][type] = dep[pre][type]+1;
- fa[now][0][type] = pre;
- for(int i = 1; i <= 20; i++)//超出范围的祖先都是0号结点
- fa[now][i][type] = fa[fa[now][i-1][type]][i-1][type];
- for(int i = 0; i < tr[now][type].size(); i++)
- if(tr[now][type][i] != pre)
- dfs(tr[now][type][i], now, type);
- }
-
- int lca(int x, int y, int type){
- if(dep[x][type] < dep[y][type]) swap(x, y);
- for(int i = 20; i >= 0; i--)
- if(dep[fa[x][i][type]][type] >= dep[y][type])
- x = fa[x][i][type];
- if(x == y) return x;
- for(int i = 20; i >= 0; i--)
- if(fa[x][i][type] != fa[y][i][type])
- x = fa[x][i][type], y = fa[y][i][type];
- return fa[x][0][type];
- }
-
- void dfs2(int now, int fa, int type, int super_fa){
- if(flag[now]){
- num[now][type]++;
- F[now][type] = super_fa;
- }
- for(int i = 0; i < tr[now][type].size(); i++){
- int to = tr[now][type][i];
- dfs2(to, now, type, super_fa);
- num[now][type] += num[to][type];
- }
- }
-
- void dfs3(int now, int type){
- if(flag[now]) temp.push_back(now);
- for(int i = 0; i < tr[now][type].size(); i++){
- int to = tr[now][type][i];
- dfs3(to, type);
- }
- }
-
- signed main()
- {
- scanf("%d%d", &n, &k);
- for(int i = 1; i <= k; i++){
- scanf("%d", &a[i]);
- flag[a[i]] = true;
- }
- for(int i = 1; i <= n; i++) scanf("%d", &w[i][1]);
- for(int i = 2; i <= n; i++){
- int to;
- scanf("%d", &to);
- tr[to][1].push_back(i);
- }
- for(int i = 1; i <= n; i++) scanf("%d", &w[i][2]);
- for(int i = 2; i <= n; i++){
- int to;
- scanf("%d", &to);
- tr[to][2].push_back(i);
- }
- dfs(1, 0, 1);
- dfs(1, 0, 2);
- int lca1 = a[1], lca2 = a[1];
- for(int i = 2; i <= k; i++){
- lca1 = lca(a[i], lca1, 1);
- lca2 = lca(a[i], lca2, 2);
- }
- int cnt1 = 0;
- for(int i = 0; i < tr[lca1][1].size(); i++){
- int to = tr[lca1][1][i];
- dfs2(to, lca1, 1, to);
- if(num[to][1]) cnt1++;
- }
- int cnt2 = 0;
- for(int i = 0; i < tr[lca2][2].size(); i++){
- int to = tr[lca2][2][i];
- dfs2(to, lca2, 2, to);
- if(num[to][2]) cnt2++;
- }
- int ans = 0;
- for(int i = 1; i <= k; i++){
- int t1 = -1, t2 = -1;
- if(cnt1 > 2) t1 = w[lca1][1];
- else if(a[i] == lca1){
- if(cnt1 == 2) t1 = w[lca1][1];
- else if(cnt1 == 1){
- temp.clear();
- for(int j = 1; j <= k; j++){
- if(a[j] == lca1) continue;
- temp.push_back(a[j]);
- }
- int top = temp[0];
- for(int j = 0; j < temp.size(); j++)
- top = lca(top, temp[j], 1);
- t1 = w[top][1];
- }
- //此时cnt1不可能为0
- }
- else if(cnt1 == 2){
- if(num[F[a[i]][1]][1] > 1) t1 = w[lca1][1];//lca不变
- else{
- temp.clear();
- for(int j = 1; j <= k; j++){
- if(a[j] == a[i]) continue;
- temp.push_back(a[j]);
- }
- int top = temp[0];
- for(int j = 0; j < temp.size(); j++)
- top = lca(top, temp[j], 1);
- t1 = w[top][1];
- }
- }
- else if(cnt1 == 1)
- t1 = w[lca1][1];
- if(cnt2 > 2) t2 = w[lca2][2];
- else if(a[i] == lca2){
- if(cnt2 == 2) t2 = w[lca2][2];
- else if(cnt2 == 1){
- temp.clear();
- for(int j = 1; j <= k; j++){
- if(a[j] == lca2) continue;
- temp.push_back(a[j]);
- }
- int top = temp[0];
- for(int j = 0; j < temp.size(); j++)
- top = lca(top, temp[j], 2);
- t2 = w[top][2];
- }
- //此时cnt1不可能为0
- }
- else if(cnt2 == 2){
- if(num[F[a[i]][2]][2] > 1) t2 = w[lca2][2];//lca不变
- else{
- temp.clear();
- for(int j = 1; j <= k; j++){
- if(a[j] == a[i]) continue;
- temp.push_back(a[j]);
- }
- int top = temp[0];
- for(int j = 0; j < temp.size(); j++)
- top = lca(top, temp[j], 2);
- t2 = w[top][2];
- }
- }
- else if(cnt2 == 1)
- t2 = w[lca2][2];
- if(t1 > t2) ans++;
- }
- printf("%d\n", ans);
- return 0;
- }