[模板]LCA倍增树算法

参考

oi-wiki-LCA

定义

最近公共祖先简称LCA(Lowest Common Ancestor),指一棵树下两个节点公共祖先中离根最远的那个。下面用LCA({v1, v2, v3, v4..})表示节点的公共祖先

性质

  1. 若u是v的祖先,
  2. 相反不为两者中的值,那么说明u,v不在同一棵子树中
  3. 前序遍历中,出现在所有S中元素之前,后序遍历中则出现在所有𝑆中元素之后;
  4. 两点集并的最近公共祖先为两点集分别的最近公共祖先的最近公共祖先,即
  5. 必定出现在u到v的最短路径上
  6. d(u,v)代表u到v的距离,h代表节点到根的距离

求法

朴素求法

过程

  1. 将其中一个节点上跳至根并记录经过的节点,第二个节点上跳时第一个经过的已记录节点就是最近公共
  2. 计算出冷个节点的高度差,将深度较大的那个节点上跳至两个节点深度一致,再同时上跳两个节点,当两个节点相遇时即可得出最近公共祖先

时间复杂度

预处理时时间复杂度为,单次查询时间复杂度为

倍增算法

过程

倍增算法由朴素算法的第二种方案优化而来,先对所有节点进行预处理,pa[i][x]表示节点i的第个祖先节点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
//n代表树的节点数量
//children[i]表示i的子节点
var pa = make([][]int, n)
var branch = make([][]int, n) //记录一条分支中的所有节点
var initPa func(int, int)
//i为节点下标,d为节点深度
initPa = func(i int, d int) {
branch[d] = i
for x := 0; d-(1<<x) >= 0; x++ {
pa[i] = append(pa[i], branch[d-(1<<x)])
}
for _, v := range children[i] {
initPa(v, d+1)
}
}

在统一节点深度时可以先计算出两个节点的高度差h,再遍历为1的二进制位,例如快速跳转到i的第13个祖先节点,13的二进制为1101, 那么就可以将13拆成8+4+1,只用进行i=pa[i][1],i=pa[i][4],i=pa[i][8]三次跳转即可到达

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
//depth表示节点的深度
if depth[i] > depth[j] {
h := depth[i] - depth[j]
m := bits.Len(uint(h))
for k := range m {
if (h >> k & 1) == 1 {
i = pa[i][k]
}
}
} else {
h := depth[j] - depth[i]
m := bits.Len(unit(h))
for k := range m {
if (h >> k & 1) == 1 {
j = pa[j][k]
}
}
}

此时两个节点深度相等,再遍历两个节点的pa,若pa[i][k] != pa[j][k]那么就将i和j分别跳转到pa[i][k]和pa[j][k]继续遍历,若遍历完两个节点的pa完全相等,那么答案即为此时i,j的第一个父节点

1
2
3
4
5
6
7
8
9
10
11
//若i,j在同一子树上,则统一深度后相遇,此时直接得出结果
if i == j {
return i
}
for x := len(pa[i]) - 1; x >= 0; x-- {
if pa[i][x] != pa[j][x] {
i, j = pa[i][x], pa[j][x]
x = len(pa[i])
}
}
return pa[i][0]

时间复杂度

预处理时间复杂度为,但单次查询时间复杂度为

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
//n代表树的节点数量
//children[i]表示i的子节点
var pa = make([][]int, n)
var branch = make([][]int, n) //记录一条分支中的所有节点
var initPa func(int, int)
//i为节点下标,d为节点深度
initPa = func(i int, d int) {
branch[d] = i
for x := 0; d-(1<<x) >= 0; x++ {
pa[i] = append(pa[i], branch[d-(1<<x)])
}
for _, v := range children[i] {
initPa(v, d+1)
}
}
initPa(0, 0)
func lca(i int, j int, children [][]int, depth []int) int {
if depth[i] > depth[j] {
h := depth[i] - depth[j]
m := bits.Len(uint(h))
for k := range m {
if (h >> k & 1) == 1 {
i = pa[i][k]
}
}
} else {
h := depth[j] - depth[i]
m := bits.Len(unit(h))
for k := range m {
if (h >> k & 1) == 1 {
j = pa[j][k]
}
}
}
if i == j {
return i
}
for x := len(pa[i]) - 1; x >= 0; x-- {
if pa[i][x] != pa[j][x] {
i, j = pa[i][x], pa[j][x]
x = len(pa[i])
}
}
return pa[i][0]
}