欢迎关注更多精彩
关注我,学习常用算法与数据结构,一题多解,降维打击。
本期话题:在树上查找2个结点的最近公共祖先
问题提出
最近公共祖先定义
最近公共祖先简称 LCA(Lowest Common Ancestor)。两个节点的最近公共祖先,就是这两个点的公共祖先里面,离根最远(深度最深)的那个。
问题
参考地址:https://www.luogu.com.cn/problem/P3379
给定一棵树,询问每两个结点的最近公共祖先,一般会询问多次。
朴素做法
- 利用dfs求出所有结点的深度和父亲结点。
- 查询时把深度大的结点往上移,直到两个结点深度一样。然后两个结点同时往上移,直到两结点相遇。
复杂度分析
第1步求深度和父亲结点,需要遍历所有结点,复杂度是O(n)。
第2步在极端情况下是O(n) , 在多次查询的情况下,效率很低。
空间换时间
试想一下我们给每1个结点分配1个空间来存储往上移n个位置到达的祖先结点。
当我们要查询两个公共祖先时,就可以使用二分查找的方法来加速。
以A, B为例,可以看到后面黄色部分是公共祖先,我们要找的是最左边的10号祖先。只要利用二分查找即可找到。
该方法可以把查询复杂度降低到log(n). 但同时空间复杂度是O(n^2)。
优化空间(倍增算法)
参考资料:https://oi-wiki.org//graph/lca/#%E5%80%8D%E5%A2%9E%E7%AE%97%E6%B3%95
上面的方法的问题是空间分配的太多了,而且仔细观察,空间是冗余的。
比如A往上1个的祖先分配的数组和A的数组是高度重合的,可以看出是有递归或继承关系的。而且我们每次都是取的数组的一半。
那么我们可以存储往上数2^n个的祖先。
即存储往上1个,2个,4个。。。的祖先分别是谁。
查询的时候,由于任意数字都可以用2进制进行组合而成,可以遍历到所有祖先。
具体算法可以类比二分算法。
代码模板
题目链接:https://www.luogu.com.cn/problem/P3379
#include<stdio.h>
#include<malloc.h>
#include<string.h>
#include<cmath>
#include<algorithm>using namespace std;const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2];
int len;
int h[N];
int father[bitL][N];void initPara(int n)
{len = 0;for (int i = 0; i < n; i++){head[i] = -1;}
}void add(int a, int b)
{to[len] = b;nextEdge[len] = head[a];head[a] = len++;
}void dfs(int x, int fa)
{if (fa == -1) h[x] = 0;else {h[x] = h[fa] + 1;father[0][x] = fa;// 利用倍增算法初始化fatherfor (int t = 1; t < bitL && (1<<t)<=h[x]; t++) {father[t][x] = father[t-1][father[t - 1][x]];}}int i;for (i = head[x]; i != -1; i = nextEdge[i]){int j = to[i];if (fa==j)continue;dfs(j, x);}
}int lca(int a, int b) {if (h[a] < h[b]) {return lca(b, a);}// 先将两个结点跳到一样高度int gap = h[a] - h[b];for (int t = bitL-1; t>=0; t--) {if (gap & (1 << t))a = father[t][a];}if (a == b)return a;gap = h[a];// 利用二分查找找到深度最低的且不一样的结点。for (int t = bitL-1; t >= 0; t--) {if (gap <=(1 << t))continue;if (father[t][a] == father[t][b])continue;a = father[t][a];b = father[t][b];gap -= 1 << t;}return father[0][a]; // 再往上1个既是公共祖先
}void solve()
{int t;int n, m, s;scanf("%d%d%d", &n, &m, &s);s--;initPara(n);int a, b;for (int i = 0; i < n - 1; ++i) {scanf("%d%d", &a, &b);a--, b--;add(a, b);add(b, a);}dfs(s, -1);/*for (int i = 0; i < n; ++i) {printf("%d: %d\n", i, h[i]);}*/while (m--) {scanf("%d%d", &a, &b);a--, b--;printf("%d\n", 1+lca(a, b));}
}void test() {int t;int n=5000, m=500000, s=1;//scanf("%d%d%d", &n, &m, &s);s--;initPara(n);int a, b;for (int i = 0; i < n - 1; ++i) {a = i, b = i + 1;add(a, b);add(b, a);}dfs(s, -1);//printf("%d\n", 1 + lca(10, 5000-1));while (m--) {a = (m+102)%n, b =( 3823+m*2)%n;//printf("%d\n", m);if(lca(a, b)!=min(a,b))printf("%d %d %d\n", 1 + lca(a, b), a+1, b+1);}
}int main()
{solve();//test();return 0;
}/*5 5 4
3 1
2 4
5 1
1 4
2 4
3 2
3 5
1 2
4 512 11 8
8 1
8 9
8 12
1 5
1 7
7 6
9 4
9 11
9 2
4 3
12 101 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
*/
练习一
链接:https://loj.ac/p/10135
注意点:需要对结点进行编号,无公共祖先时返回-1
#include <stdio.h>
#include <malloc.h>
#include <string.h>
#include <cmath>
#include <algorithm>
#include <map>using namespace std;const int M = 500000 + 10;
const int N = 500000 + 10;map<int, int> num2Ind;
int indLen;void initIndex() {num2Ind.clear();indLen = 0;
}int getIndex(int n) {if (num2Ind.count(n) == 0)return -1;return num2Ind[n];
}int addIndex(int n) {if (num2Ind.count(n) == 0)num2Ind[n] = indLen++;return num2Ind[n];
}const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2];
int len;
int h[N];
int father[bitL][N];void initPara(int n) {len = 0;for (int i = 0; i < n; i++) {head[i] = -1;}
}void add(int a, int b) {to[len] = b;nextEdge[len] = head[a];head[a] = len++;
}void dfs(int x, int fa) {if (fa == -1)h[x] = 0;else {h[x] = h[fa] + 1;father[0][x] = fa;for (int t = 1; t < bitL && (1 << t) <= h[x]; t++) {father[t][x] = father[t - 1][father[t - 1][x]];}}int i;for (i = head[x]; i != -1; i = nextEdge[i]) {int j = to[i];if (fa == j)continue;dfs(j, x);}
}int lca(int a, int b) {if (h[a] < h[b]) {return lca(b, a);}int gap = h[a] - h[b];for (int t = bitL - 1; t >= 0; t--) {if (gap & (1 << t))a = father[t][a];}if (a == b)return a;gap = h[a];for (int t = bitL - 1; t >= 0; t--) {if (gap <= (1 << t))continue;if (father[t][a] == father[t][b])continue;a = father[t][a];b = father[t][b];gap -= 1 << t;}return father[0][a];
}void solve() {int n, m;int a, b, s;scanf("%d", &n);initPara(n);for (int i = 0; i < n; ++i) {scanf("%d%d", &a, &b);if (b == -1) {s = addIndex(a);continue;}a = addIndex(a);b = addIndex(b);add(a, b);add(b, a);}dfs(s, -1);scanf("%d", &m);/*for (int i = 0; i < n; ++i) {printf("%d: %d\n", i, h[i]);}*/while (m--) {scanf("%d%d", &a, &b);a = getIndex(a);b = getIndex(b);if (a < 0 || b < 0 || a == b) {puts("0");continue;}int lcab = lca(a, b);if (lcab == a)puts("1");else if (lcab == b)puts("2");elseputs("0");}
}int main() {solve();return 0;
}/*
3
2 -1
1 2
3 1
2
1 2
2 32
1 -1
1 2
2
1 2
2 112
8 -1
8 1
8 9
8 12
1 5
1 7
7 6
9 4
9 11
9 2
4 3
12 10
11
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 1210
234 -1
12 234
13 234
14 234
15 234
16 234
17 234
18 234
19 234
233 19
5
234 233
233 12
233 13
233 15
233 19
*/
练习二
链接:https://loj.ac/p/2610
算法思路:先用最小生成算法把所有大的边加入到树中。
利用倍增算法建立祖先关系,以及到祖先链路上的最小负载。
查询A,B最小负载为=min(A到公共祖先最小负载,B到公共祖先最小负载)。
具体实现分别从A,B查找最近公共祖先时记录链路上的最小值。
注意点:需要事先判断是否可达。题目中规定A!=B。
利用并查集点击前往判断是否在一棵树中。
#include<stdio.h>
#include<malloc.h>
#include<string.h>
#include<cmath>
#include<algorithm>
#include<vector>using namespace std;class UnionFindSet {
private:vector<int> father; // 父结点定义,father[i]=i时,i为本集合的代表vector<int> height; // 代表树高度,初始为1int nodeNum; // 集合中的点数public:UnionFindSet(int n); // 初始化bool Union(int x, int y); // 合并int Find(int x);bool UnionV2(int x, int y); // 合并int FindV2(int x);
};UnionFindSet::UnionFindSet(int n) : nodeNum(n + 1) {father = vector<int>(nodeNum);height = vector<int>(nodeNum);for (int i = 0; i < nodeNum; ++i) father[i] = i, height[i] = 1; // 初始为自己
}int UnionFindSet::Find(int x) {while (father[x] != x) x = father[x];return x;
}bool UnionFindSet::Union(int x, int y) {x = Find(x);y = Find(y);if (x == y)return false;father[x] = y;return true;
}int UnionFindSet::FindV2(int x) {int root = x; // 保存好路径上的头结点while (father[root] != root) {root = father[root];}/*从头结点开始一直往根上遍历把所有结点的father直接指向root。*/while (father[x] != x) {// 一定要先保存好下一个结点,下一步是要对father[x]进行赋值int temp = father[x];father[x] = root;x = temp;}return root;
}/*
需要加入height[]属性,初始化为1.
*/
//合并结点
bool UnionFindSet::UnionV2(int x, int y) {x = Find(x);y = Find(y);if (x == y) {return false;}if (height[x] < height[y]) {father[x] = y;}else if (height[x] > height[y]) {father[y] = x;}else {father[x] = y;height[y]++;}return true;
}const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2], weight[M * 2];
int len;
int h[N];
int father[bitL][N];
int dis[bitL][N];void initPara(int n)
{len = 0;for (int i = 0; i < n; i++){head[i] = -1;h[i] = -1;}
}void add(int a, int b, int w)
{to[len] = b;weight[len] = w;nextEdge[len] = head[a];head[a] = len++;
}void dfs(int x, int fa, int w)
{if (fa == -1) h[x] = 0;else {h[x] = h[fa] + 1;father[0][x] = fa;dis[0][x] = w;for (int t = 1; t < bitL && (1 << t) <= h[x]; t++) {father[t][x] = father[t - 1][father[t - 1][x]];dis[t][x] = min(dis[t - 1][x], dis[t - 1][father[t - 1][x]]);}}int i;for (i = head[x]; i != -1; i = nextEdge[i]){int j = to[i];if (fa == j)continue;dfs(j, x, weight[i]);}
}int lca(int a, int b) {if (h[a] < h[b]) {return lca(b, a);}int gap = h[a] - h[b];for (int t = bitL - 1; t >= 0; t--) {if (gap & (1 << t))a = father[t][a];}if (a == b)return a;gap = h[a];for (int t = bitL - 1; t >= 0; t--) {if (gap <= (1 << t))continue;if (father[t][a] == father[t][b])continue;a = father[t][a];b = father[t][b];gap -= 1 << t;}return father[0][a];
}int optDis(int a, int b) {if (h[a] < h[b]) {return optDis(b, a);}int d = 1e6;int gap = h[a] - h[b];for (int t = bitL - 1; t >= 0; t--) {if (gap & (1 << t)) {d=min(d, dis[t][a]);a = father[t][a];}}if (a == b)return d;gap = h[a];for (int t = bitL - 1; t >= 0; t--) {if (gap <= (1 << t))continue;if (father[t][a] == father[t][b])continue;d = min(d,dis[t][a]);d = min(d,dis[t][b]);a = father[t][a];b = father[t][b];gap -= 1 << t;}d = min(d, min(dis[0][a], dis[0][b]));return d;
}bool cmp(vector<int> &a, vector<int> &b) {return a[2] > b[2];
}void solve()
{int n, m;int a, b, w;scanf("%d%d", &n, &m);initPara(n);auto us = UnionFindSet(n);vector<vector<int>> eds;for (int i = 0; i < m; ++i) {scanf("%d%d%d", &a, &b, &w);a--, b--;eds.push_back({a,b,w});}sort(eds.begin(), eds.end(), cmp);for (auto ed : eds) {if (us.UnionV2(ed[0], ed[1])) {add(ed[0], ed[1], ed[2]);add(ed[1], ed[0], ed[2]);}}for (int i = 0; i < n; ++i) {if(h[i]<0)dfs(i, -1, 0);}scanf("%d", &m);while (m--) {scanf("%d%d", &a, &b);a--, b--;if (us.FindV2(a) != us.FindV2(b))puts("-1");else printf("%d\n", optDis(a, b));}
}int main()
{solve();return 0;
}/*
12 11
8 1 4
8 9 3
8 12 6
1 5 5
1 7 1
7 6 2
9 4 2
9 11 10
9 2 9
4 3 2
12 10 711
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 1212 11
8 1 1
8 9 1
8 12 1
1 5 1
1 7 1
7 6 1
9 4 1
9 11 1
9 2 1
4 3 1
12 10 111
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12*/
练习三
https://loj.ac/p/10130
算法思路:
利用倍增算法建立祖先关系,以及到祖先链路上的距离。
查询A,B距离=A到公共祖先距离+B到公共祖先距离)。
具体实现与上一题类似。
#include<stdio.h>
#include<malloc.h>
#include<string.h>
#include<cmath>
#include<algorithm>using namespace std;const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2], weight[M * 2];
int len;
int h[N];
int father[bitL][N];
int dis[bitL][N];void initPara(int n)
{len = 0;for (int i = 0; i < n; i++){head[i] = -1;}
}void add(int a, int b, int w)
{to[len] = b;weight[len] = w;nextEdge[len] = head[a];head[a] = len++;
}void dfs(int x, int fa, int w)
{if (fa == -1) h[x] = 0;else {h[x] = h[fa] + 1;father[0][x] = fa;dis[0][x] = w;for (int t = 1; t < bitL && (1 << t) <= h[x]; t++) {father[t][x] = father[t - 1][father[t - 1][x]];dis[t][x] = dis[t - 1][x] + dis[t - 1][father[t - 1][x]];}}int i;for (i = head[x]; i != -1; i = nextEdge[i]){int j = to[i];if (fa == j)continue;dfs(j, x, weight[i]);}
}int lca(int a, int b) {if (h[a] < h[b]) {return lca(b, a);}int gap = h[a] - h[b];for (int t = bitL - 1; t >= 0; t--) {if (gap & (1 << t))a = father[t][a];}if (a == b)return a;gap = h[a];for (int t = bitL - 1; t >= 0; t--) {if (gap <= (1 << t))continue;if (father[t][a] == father[t][b])continue;a = father[t][a];b = father[t][b];gap -= 1 << t;}return father[0][a];
}int optDis(int a, int b) {if (h[a] < h[b]) {return optDis(b, a);}int d = 0;int gap = h[a] - h[b];for (int t = bitL - 1; t >= 0; t--) {if (gap & (1 << t)) {d += dis[t][a];a = father[t][a];}}if (a == b)return d;gap = h[a];for (int t = bitL - 1; t >= 0; t--) {if (gap <= (1 << t))continue;if (father[t][a] == father[t][b])continue;d += dis[t][a];d += dis[t][b];a = father[t][a];b = father[t][b];gap -= 1 << t;}d += dis[0][a] + dis[0][b];return d;
}void solve()
{int n, m;int a, b;scanf("%d", &n);initPara(n);for (int i = 0; i < n - 1; ++i) {scanf("%d%d", &a, &b);a--, b--;add(a, b, 1);add(b, a, 1);}dfs(0, -1, 0);scanf("%d", &m);while (m--) {scanf("%d%d", &a, &b);a--, b--;printf("%d\n", optDis(a, b));}
}int main()
{solve();return 0;
}/*
6
1 2
1 3
2 4
2 5
3 6
2
2 6
5 612
8 1
8 9
8 12
1 5
1 7
7 6
9 4
9 11
9 2
4 3
12 10
11
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12*/
练习四
https://acm.hdu.edu.cn/showproblem.php?pid=2586
与练习三类似
#include<stdio.h>
#include<malloc.h>
#include<string.h>
#include<cmath>
#include<algorithm>using namespace std;const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2],weight[M*2];
int len;
int h[N];
int father[bitL][N];
int dis[bitL][N];void initPara(int n)
{len = 0;for (int i = 0; i < n; i++){head[i] = -1;}
}void add(int a, int b, int w)
{to[len] = b;weight[len] = w;nextEdge[len] = head[a];head[a] = len++;
}void dfs(int x, int fa, int w)
{if (fa == -1) h[x] = 0;else {h[x] = h[fa] + 1;father[0][x] = fa;dis[0][x] = w;for (int t = 1; t < bitL && (1<<t)<=h[x]; t++) {father[t][x] = father[t-1][father[t - 1][x]];dis[t][x] = dis[t-1][x]+ dis[t - 1][father[t - 1][x]];}}int i;for (i = head[x]; i != -1; i = nextEdge[i]){int j = to[i];if (fa==j)continue;dfs(j, x, weight[i]);}
}int lca(int a, int b) {if (h[a] < h[b]) {return lca(b, a);}int gap = h[a] - h[b];for (int t = bitL-1; t>=0; t--) {if (gap & (1 << t))a = father[t][a];}if (a == b)return a;gap = h[a];for (int t = bitL-1; t >= 0; t--) {if (gap <=(1 << t))continue;if (father[t][a] == father[t][b])continue;a = father[t][a];b = father[t][b];gap -= 1 << t;}return father[0][a];
}int optDis(int a, int b) {if (h[a] < h[b]) {return optDis(b, a);}int d = 0;int gap = h[a] - h[b];for (int t = bitL - 1; t >= 0; t--) {if (gap & (1 << t)) {d += dis[t][a];a = father[t][a];}}if (a == b)return d;gap = h[a];for (int t = bitL - 1; t >= 0; t--) {if (gap <= (1 << t))continue;if (father[t][a] == father[t][b])continue;d += dis[t][a];d += dis[t][b];a = father[t][a];b = father[t][b];gap -= 1 << t;}d += dis[0][a] + dis[0][b];return d;
}void solve()
{int t;int n, m;int a, b, w;scanf("%d", &t);while (t--) {scanf("%d%d", &n, &m);initPara(n);for (int i = 0; i < n - 1; ++i) {scanf("%d%d%d", &a, &b, &w);a--, b--;add(a, b,w);add(b, a,w);}dfs(0, -1, 0);/*for (int i = 0; i < n; ++i) {printf("%d: %d\n", i, h[i]);}*/while (m--) {scanf("%d%d", &a, &b);a--, b--;printf("%d\n", optDis(a,b));}}
}void test() {int t;int n = 5000, m = 500000;//scanf("%d%d%d", &n, &m, &s);initPara(n);int a, b;for (int i = 0; i < n - 1; ++i) {a = i, b = i + 1;add(a, b,1);add(b, a,1);}dfs(0, -1,0);//printf("%d\n", 1 + lca(10, 5000-1));while (m--) {a = (m+102)%n, b =( 3823+m*2)%n;//printf("%d\n", m);if(lca(a, b)!=min(a,b))printf("%d %d %d\n", 1 + lca(a, b), a+1, b+1);}
}int main()
{solve();//test();return 0;
}/*
2
3 2
1 2 10
3 1 15
1 2
2 32 2
1 2 100
1 2
2 11
12 11
8 1 4
8 9 3
8 12 6
1 5 5
1 7 1
7 6 2
9 4 2
9 11 10
9 2 9
4 3 2
12 10 71 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 121
12 11
8 1 1
8 9 1
8 12 1
1 5 1
1 7 1
7 6 1
9 4 1
9 11 1
9 2 1
4 3 1
12 10 11 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
*/
本人码农,希望通过自己的分享,让大家更容易学懂计算机知识。创作不易,帮忙点击公众号的链接。