树链剖分是一种进行树上操作的在线算法,剖分的方式分为重链剖分和长链剖分。本文介绍重链剖分,其通过利用树上dfs时间戳的连续性,以及轻重链概念的引入,可以将一棵树转化为一段连续序列,一条路径转化为不超过 logn 段连续区间,结合线段树区间维护,可以较为快速的树上查询/修改。
关于线段树:
重儿子:父节点所有儿子中,所在子树结点数目最多的结点。
重边:父节点和重儿子连成的边。
重链:由多条重边连接而成的路径。
如下图,红色边代表重边,黄色结点代表重儿子,红色边连接而成的路径就是重链
结论1、2不再证明,只证明结论3:
证明比较巧妙(似乎所有算法都是这样,雾)
我们只需证明根节点到任意叶子结点的路径上最多不超过logn条重链
我们考虑当前结点 cur ,父节点fa,son[fa]为fa的重儿子,size[u]代表结点u为根子树的大小
那么我们从叶子结点往上走,会经历若干重链,若干轻边,我们下面证明经过一条轻边后,size[cur]至少变为原来2倍
得证,故从叶子结点向上最多走logn条轻边,而因为经过一条轻边一定会先经过一条重链,故重链也最多logn条,故得证,故结论3得证
f[u]: u的父节点
dep[u]: u的深度
son[u]: u的重儿子
sz[u]: u所在子树大小
top[u]: u所在重链的顶点
dfs预处理:O(n+m)
void dfs1(int x, int father){
//深度、父节点、初始化sz
dep[x] = dep[father] + 1, fa[x] = father, sz[x] = 1;
for(int i = head[x]; ~i; i = edges[i].nxt){
int y = edges[i].v;
if(y == father) continue;
dfs1(y, x);
sz[x] += sz[y]; //累加sz
if(sz[son[x]] < sz[y]) son[x] = y; //维护重儿子
}
}
void dfs2(int x, int t){
top[x] = t;
if(!son[x]) return;
dfs2(son[x], t); //重儿子和父亲的top相同
for(int i = head[x]; ~i; i = edges[i].nxt){
int y = edges[i].v;
if(y == fa[x] || y == son[x]) continue;
dfs2(y, y); //轻儿子自己就是所在重链的顶点
}
}
重链剖分一个比较经典的应用就是求lca,效率比倍增要快些,不过二者代码都挺好写,当然tarjan也不错,不过个人感觉tarjan那个思想容易忘。
对于查询(x, y),二者到lca的路径上都满足不超过logn条重链,我们由于预处理了top和fa,我们可以通过x = fa[top[x]]跳到上一条重链
那么x和y各自向上跳不超过logn次就能到达同一条重链了,而且必然满足最终结果为x或y处于lca的位置,这个比较简单,可以自己想一下
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 5e5 + 10, M = 1e6 + 10;
int n, m, root, head[N], idx;
int sz[N], fa[N], son[N], top[N], dep[N];
struct edge{
int v, nxt;
}edges[M];
void addedge(int u, int v){
edges[idx] = { v, head[u] }, head[u] = idx++;
}
void add(int u, int v){
addedge(u, v), addedge(v, u);
}
void dfs1(int x, int father){
dep[x] = dep[father] + 1, fa[x] = father, sz[x] = 1;
for(int i = head[x]; ~i; i = edges[i].nxt){
int y = edges[i].v;
if(y == father) continue;
dfs1(y, x);
sz[x] += sz[y];
if(sz[son[x]] < sz[y]) son[x] = y;
}
}
void dfs2(int x, int t){
top[x] = t;
if(!son[x]) return;
dfs2(son[x], t);
for(int i = head[x]; ~i; i = edges[i].nxt){
int y = edges[i].v;
if(y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}
int lca(int x, int y){
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
//此时已经在一条重链上
return dep[x] < dep[y] ? x : y;
}
int main(){
//freopen("in.txt", "r", stdin);
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
memset(head, -1, sizeof head);
cin >> n >> m >> root;
for(int i = 0, a, b; i < n - 1; i++)
cin >> a >> b, add(a, b);
dfs1(root, 0), dfs2(root, root);
for(int i = 0, a, b; i < m; i++)
cin >> a >> b, cout << lca(a, b) << '\n';
return 0;
}
如果只是用来求lca,那么就杀鸡用牛刀了,重链剖分的精髓其实在于树上查询/修改。
我们尝试在dfs2的同时为每个结点打上时间戳,dfs中时间戳是一个非常重要的概念,我们有一个经典的结论就是:**同一个连通块内结点的时间戳集合是一个连续区间。**这也是tarjan算法的核心,因为我们深搜一定是搜完了当前连通块才会去搜其他连通块,所以同一连通块内的时间戳集合自然是连续的。
那么当我们在dfs2中为结点打上时间戳后,我们多了哪些可用信息?我们不妨记时间戳数组为id[]
基于上面三条,如果树中结点带权的话,我们可以利用线段树和时间戳来进行任意子树,任意路径的权值和的查询/修改
初始化:
任意子树修改
子树u结点区间为 [id[u], id[u] + sz[u] - 1]
直接调用线段树区间修改接口即可
时间复杂度:O(log n)
任意子树查询
子树u结点区间为 [id[u], id[u] + sz[u] - 1]
直接调用线段树区间查询接口即可
时间复杂度:O(log n)
任意路径修改
这样就完成了路径上所有重链上的结点的修改,自然完成了路径修改
时间复杂度:O(log^2 n)
任意路径查询
和路径修改相同,只不过不是进行区间修改而是区间查询
时间复杂度:O(log^2 n)
void dfs1(int u, int father, int dep)
{ // 父子关系以及sz处理
d[u] = dep, fa[u] = father, sz[u] = 1;
for (int i = head[u]; ~i; i = edges[i].nxt)
{
int v = edges[i].v;
if (v == father)
continue;
dfs1(v, u, dep + 1);
sz[u] += sz[v];
if (sz[son[u]] < sz[v])
son[u] = v;
}
}
void dfs2(int u, int t)
{
nw[id[u] = ++tot] = w[u], top[u] = t; //tot用来记录当前时间戳
if (!son[u])
return;
dfs2(son[u], t);
for (int i = head[u]; ~i; i = edges[i].nxt)
{
int v = edges[i].v;
if (v == son[u] || v == fa[u])
continue;
dfs2(v, v);
}
}
//void pushup(int p)
//void pushdown(int p) 标记下传
//void update(int p, int l, int r, int k) 区间修改
//LL query(int p, int l, int r) 区间查询
void build(int p, int l, int r) //递归建树
{
tr[p] = {l, r, nw[l]};
if (l == r)
return;
int mid = (l + r) >> 1;
build(lc, l, mid), build(rc, mid + 1, r);
pushup(p);
}
void update_path(int x, int y, int k) //路径修改
{
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]])
swap(x, y);
update(1, id[top[x]], id[x], k);
x = fa[top[x]];
}
if (d[x] < d[y])
swap(x, y);
update(1, id[y], id[x], k);
}
LL query_path(int x, int y) //路径查询
{
LL res = 0;
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]])
swap(x, y);
res = (res + query(1, id[top[x]], id[x])) % mod;
x = fa[top[x]];
}
if (d[x] < d[y])
swap(x, y);
res = (res + query(1, id[y], id[x])) % mod;
return res;
}
void update_tr(int x, int k) //子树修改
{
update(1, id[x], id[x] + sz[x] - 1, k);
}
LL query_tr(int x) //子树查询
{
return query(1, id[x], id[x] + sz[x] - 1);
}
//main
dfs1(root, -1, 1);
dfs2(root, root);
build(1, 1, n);
板子题,复现上面的算法流程即可
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
#define lc p << 1
#define rc p << 1 | 1
const int N = 1e5 + 10, M = N << 1;
int n, m, root, mod;
struct edge
{
int v, nxt;
} edges[M];
int head[N], idx;
int w[N], id[N], nw[N], tot;
int d[N], sz[N], top[N], fa[N], son[N];
struct node
{
int l, r;
LL sum, tag;
} tr[N << 2];
void addedge(int u, int v)
{
edges[idx] = {v, head[u]}, head[u] = idx++;
}
void add(int u, int v)
{
addedge(u, v), addedge(v, u);
}
void dfs1(int u, int father, int dep)
{ // 父子关系以及sz处理
d[u] = dep, fa[u] = father, sz[u] = 1;
for (int i = head[u]; ~i; i = edges[i].nxt)
{
int v = edges[i].v;
if (v == father)
continue;
dfs1(v, u, dep + 1);
sz[u] += sz[v];
if (sz[son[u]] < sz[v])
son[u] = v;
}
}
void dfs2(int u, int t)
{
nw[id[u] = ++tot] = w[u], top[u] = t;
if (!son[u])
return;
dfs2(son[u], t);
for (int i = head[u]; ~i; i = edges[i].nxt)
{
int v = edges[i].v;
if (v == son[u] || v == fa[u])
continue;
dfs2(v, v);
}
}
void pushup(int p)
{
tr[p].sum = (tr[lc].sum + tr[rc].sum) % mod;
}
void build(int p, int l, int r)
{
tr[p] = {l, r, nw[l]};
if (l == r)
return;
int mid = (l + r) >> 1;
build(lc, l, mid), build(rc, mid + 1, r);
pushup(p);
}
void pushdown(int p)
{
if (tr[p].tag)
{
tr[lc].sum = (tr[lc].sum + (tr[lc].r - tr[lc].l + 1) * tr[p].tag + mod) % mod;
tr[rc].sum = (tr[rc].sum + (tr[rc].r - tr[rc].l + 1) * tr[p].tag + mod) % mod;
tr[lc].tag += tr[p].tag, tr[rc].tag += tr[p].tag;
tr[p].tag = 0;
}
}
void update(int p, int l, int r, int k)
{
if (l <= tr[p].l && tr[p].r <= r)
{
tr[p].tag += k, tr[p].sum = (tr[p].sum + (tr[p].r - tr[p].l + 1) * k + mod) % mod;
return;
}
pushdown(p);
int mid = (tr[p].l + tr[p].r) >> 1;
if (l <= mid)
update(lc, l, r, k);
if (r > mid)
update(rc, l, r, k);
pushup(p);
}
LL query(int p, int l, int r)
{
if (l <= tr[p].l && tr[p].r <= r)
{
return tr[p].sum;
}
pushdown(p);
int mid = (tr[p].l + tr[p].r) >> 1;
LL res = 0;
if (l <= mid)
res = (res + query(lc, l, r)) % mod;
if (r > mid)
res = (res + query(rc, l, r)) % mod;
return res;
}
void update_path(int x, int y, int k)
{
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]])
swap(x, y);
update(1, id[top[x]], id[x], k);
x = fa[top[x]];
}
if (d[x] < d[y])
swap(x, y);
update(1, id[y], id[x], k);
}
LL query_path(int x, int y)
{
LL res = 0;
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]])
swap(x, y);
res = (res + query(1, id[top[x]], id[x])) % mod;
x = fa[top[x]];
}
if (d[x] < d[y])
swap(x, y);
res = (res + query(1, id[y], id[x])) % mod;
return res;
}
void update_tr(int x, int k)
{
update(1, id[x], id[x] + sz[x] - 1, k);
}
LL query_tr(int x)
{
return query(1, id[x], id[x] + sz[x] - 1);
}
int main()
{
//freopen("in.txt", "r", stdin);
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
memset(head, -1, sizeof head);
cin >> n >> m >> root >> mod;
for (int i = 1; i <= n; i++)
cin >> w[i], w[i] %= mod;
for (int i = 1, a, b; i < n; i++)
cin >> a >> b, add(a, b);
dfs1(root, -1, 1);
dfs2(root, root);
build(1, 1, n);
for (int i = 0, a, b, c; i < m; i++)
{
cin >> a;
if (a == 1)
{
cin >> a >> b >> c;
update_path(a, b, c);
}
else if (a == 2)
{
cin >> a >> b;
cout << query_path(a, b) << '\n';
}
else if (a == 3)
{
cin >> a >> b;
update_tr(a, b);
}
else
{
cin >> a;
cout << query_tr(a) << '\n';
}
}
/*
1 x y z,表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z
2 x y,表示求树从 x 到 y 结点最短路径上所有节点的值之和。
3 x z,表示将以 x 为根节点的子树内所有节点值都加上 z。
4 x 表示求以 x 为根节点的子树内所有节点值之和
*/
return 0;
}
[P2146
这个题就很板子,而且很舒服,因为fa数组直接给你了
那么题目两个操作的指向性很强:
install x就是从根到x路径全变1
uninstall x就是从根到x路径全变0
这个其实还是重剖板子题,甚至更简单,我们只需要想一下怎么处理线段树的标记即可。
对于线段树的标记,如果为-1,则表示无标记,不需下传,如果为1,则代表结点左右子区间全变1,如果为0,则代表结点左右子区间全变0,这个标记下传自然是可以实现的
对于每次要输出改变多少结点状态,我们先存一下根节点的sum值也就是整个区间的和,然后跟修改后的根结点sum值做差即可
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 1e5 + 10;
#define lc p << 1
#define rc p << 1 | 1
int n, q, head[N], idx, tot;
int id[N], fa[N], d[N], w[N], sz[N], top[N], son[N];
struct edge{
int v, nxt;
}edges[N];
void addedge(int u, int v){
edges[idx] = { v, head[u] }, head[u] = idx++;
}
struct node{
int l, r, sum, tag;
}tr[N << 2];
void build(int p, int l, int r){
tr[p] = { l, r, 0, -1 };
if(l == r) return;
int mid = (tr[p].l + tr[p].r) >> 1;
build(lc, l, mid), build(rc, mid + 1, r);
}
void pushup(int p){
tr[p].sum = tr[lc].sum + tr[rc].sum;
}
void pushdown(int p){
if(~tr[p].tag){
tr[lc].sum = (tr[lc].r - tr[lc].l + 1) * tr[p].tag;
tr[rc].sum = (tr[rc].r - tr[rc].l + 1) * tr[p].tag;
tr[lc].tag = tr[rc].tag = tr[p].tag;
tr[p].tag = -1;
}
}
void update(int p, int l, int r, int k){
if(l <= tr[p].l && tr[p].r <= r){
tr[p].tag = k, tr[p].sum = (tr[p].r - tr[p].l + 1) * k;
return;
}
pushdown(p);
int mid = (tr[p].l + tr[p].r) >> 1;
if(l <= mid) update(lc, l, r, k);
if(r > mid) update(rc, l, r, k);
pushup(p);
}
int query(int p, int l, int r){
if(l <= tr[p].l && tr[p].r <= r){
return tr[p].sum;
}
pushdown(p);
int mid = (tr[p].l + tr[p].r) >> 1, ret = 0;
if(l <= mid) ret += query(lc, l, r);
if(r > mid) ret += query(rc, l, r);
return ret;
}
void update_path(int x, int y, int k){
while(top[x] != top[y]){
if(d[top[x]] < d[top[y]]) swap(x, y);
update(1, id[top[x]], id[x], k);
x = fa[top[x]];
}
if(d[x] < d[y]) swap(x, y);
update(1, id[y], id[x], k);
}
void update_tr(int x, int k){
update(1, id[x], id[x] + sz[x] - 1, k);
}
void dfs1(int u, int dep){
d[u] = dep, sz[u] = 1;
for(int i = head[u]; ~i; i = edges[i].nxt){
int v = edges[i].v;
dfs1(v, dep + 1);
sz[u] += sz[v];
if(sz[son[u]] < sz[v]) son[u] = v;
}
}
void dfs2(int u, int t){
id[u] = ++tot, top[u] = t;
if(!son[u]) return;
dfs2(son[u], t);
for(int i = head[u]; ~i; i = edges[i].nxt){
int v = edges[i].v;
if(v != son[u])
dfs2(v, v);
}
}
int main(){
//freopen("in.txt", "r", stdin);
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
memset(head, -1, sizeof head);
cin >> n;
for(int i = 2; i <= n; i++) cin >> fa[i], addedge(++fa[i], i);
dfs1(1, 1), dfs2(1, 1), build(1, 1, n);
cin >> q;
string opt;
for(int i = 0, x; i < q; i++){
cin >> opt >> x, ++x;
if(opt[0] == 'i'){
int t = tr[1].sum;
update_path(1, x, 1);
cout << tr[1].sum - t << '\n';
}
else{
int t = tr[1].sum;
update_tr(x, 0);
cout << t - tr[1].sum << '\n';
}
}
return 0;
}
因篇幅问题不能全部显示,请点此查看更多更全内容