树链剖分学习笔记

算法笔记 · 2023-01-14

前言

我们知道线段树可以处理区间修改/查询操作,那么如果这个区间修改/查询变成了树上任意一条路径上点的修改/查询呢?比如我们需要满足一下几个操作 :

  • 给 $u$ 到 $v$ 路径上的每一点都加上 $k$
  • 查询 $u$ 到 $v$ 路径上每一点的和。
    如果数据大的话,暴力算法是不能胜任的。

    树剖

    我们考虑将树上问题转化为区间问题,然后用线段树处理。但是一条 $u$ 到 $v$ 路径上每一点在线段树的位置不一定是连续的,所以我们可能要进行多次区间操作,因此最坏情况下甚至不如朴素算法。所以我们尽量让多的点的编号连续。

概念

现在我们引入几个概念和定义:

  • 时间戳 $dfn_x$:dfs时到达这个点的时间成为这个点的时间戳;
  • 重儿子 $son_x$:所有儿子中子树最大的儿子;
  • 重边:由 $x$ 到 $x$ 的重儿子的一条边;
  • 轻边:由 $x$ 到 $x$ 的非重儿子的边;
  • 重链:若干条连续的重边合称为连续的重链;
  • $top_x$:$x$ 所在重链的顶端(当 $x$ 不在重链上时可以理解为自己到自己的重链 $top_x=x$);
  • f_x:$x$ 的父亲;
  • dep_x:$x$ 的深度。

    核心思想

    对于以上的东西,我们可以通过两编 $dfs$ 来求解:

  • 第一遍:求出深度,时间戳,子树大小,重儿子和每个点的父亲。
  • 第二遍:求出当前节点所在重链的顶端。

    void init1(int x,int fa){
      dep=dep[fa]+1;
      f=fa;
      size=1;
      for(int i=head;i;i=edge[i].next){
          int y=edge[i].to;
          if(y==fa)continue;
          init1(y,x);
          size+=size[y];
          if(size[y]>size[hs])hs=y;
      }
      return;
    }
    void init2(int x,int t){
      top=t;
      dfn=++tot;
      if(hs)init2(hs,t);
      for(int i=head;i;i=edge[i].next){
          int y=edge[i].to;
          if(y!=hs&&y!=f){
              init2(y,y);
          }
      }
      return;
    }

    我们将一个点在区间中的位置设为它的时间戳,所以我们只需要每次优先遍历重儿子就可以保证重链上的点在区间中的连续性。如果我们要对 $u$ 到 $v$ 的路径进行操作,我们就依次将区间 $[x,top_x]$进行修改,令 $x=f_{top_x}$。代码如下:

    while(top!=top[y]){//当x和y在同一重链时退出。
      if(dep[top]<dep[top[y]])swap(x,y);
      modify(1,1,n,dfn[top],dfn,v);//操作
      x=f[top];//操作
    }
    if(dep<dep[y])swap(x,y);
    modify(1,1,n,dfn[y],dfn,v);

    树剖的性质

  • 树上的每个节点都属于且仅属于一条重链;
  • 重链开头的节点不一定是其父亲的重儿子;
  • 一条重链所有点的 $dfn$ 是连续的;
  • 一颗子树的 $dfn$ 是连续的。

    复杂度分析

    可以发现,我们每次向下走一条轻边,子树的大小会减少一半及以上。因此,从根到某一点一定会经过不超过 $\log n$ 条重链,再乘上线段树的复杂度,树剖的复杂度是 $O(n \log^2 n)$,但是事实上,想要构造数据使得树剖达到最坏情况是肥肠困难的。
    __

    例题

    洛谷P3384【模板】重链剖分/树链剖分

    题意

给定一颗 $N$ 个节点的树,每个节点有一个权值,现在要你实现以下操作:

  • 1 x y z,表示将树从 $x$ 到 $y$ 结点最短路径上所有节点的值都加上 $z$。
  • 2 x y,表示求树从 $x$ 到 $y$ 结点最短路径上所有节点的值之和。
  • 3 x z,表示将以 $x$ 为根节点的子树内所有节点值都加上 $z$。
  • 4 x 表示求以 $x$ 为根节点的子树内所有节点值之和

数据范围
对于 $30\%$ 的数据: $1 \leq N \leq 10$,$1 \leq M \leq 10$;

对于 $70\%$ 的数据: $1 \leq N \leq {10}^3$,$1 \leq M \leq {10}^3$;

对于 $100\%$ 的数据: $1\le N \leq {10}^5$,$1\le M \leq {10}^5$,$1\le R\le N$,$1\le P \le 2^{31}-1$。
_

讲解

对于这道题,我们多了对子树的操作。我们发现 $dfn$ 的性质有:一颗子树的 $dfn$ 是连续的,则 $x$ 的子树的 $dfn$ 值就是 $[dfn_x,dfn_x+size_x-1]$ 所以操作子树就可以看做操作区间 $[dfn_x,dfn_x+size_x-1]$。

#include<iostream>
#include<cstring>
#include<algorithm>
#define lson (x<<1)
#define rson (lson)|1
using namespace std;
const int maxn=100010;
int n,m,p,cnt,tot,root,head[maxn];
int f[maxn],hs[maxn],top[maxn],dep[maxn],dfn[maxn],size[maxn];
long long b[maxn];
struct G{
    long long sum,tag;
}a[maxn<<2];
struct E{
    int to,next;
}edge[maxn<<1];
void add(int u,int v){
    edge[++cnt].to=v;edge[cnt].next=head[u];head[u]=cnt;
}
void push_down(int x,int l,int r){
    a[lson].tag=(a[lson].tag+a.tag)%p;
    a[rson].tag=(a[rson].tag+a.tag)%p;
    int mid=(l+r)>>1;
    a[lson].sum=(a[lson].sum+a.tag*(mid-l+1))%p;
    a[rson].sum=(a[rson].sum+a.tag*(r-mid))%p;
    a.tag=0;
    return;
}
void modify(int x,int l,int r,int L,int R,long long v){
    if(l>R||r<L)return;
    if(L<=l&&r<=R){
        a.tag=(a.tag+v)%p;
        a.sum=(a.sum+(r-l+1)*v)%p;
        return;
    }
    int mid=(l+r)>>1;
    push_down(x,l,r);
    modify(lson,l,mid,L,R,v);
    modify(rson,mid+1,r,L,R,v);
    a.sum=(a[lson].sum+a[rson].sum)%p;
    return;
}
long long query(int x,int l,int r,int L,int R){
    if(l>R||r<L)return 0;
    if(L<=l&&r<=R){
        return a.sum%p;
    }
    int mid=(l+r)>>1;
    push_down(x,l,r);
    return (query(lson,l,mid,L,R)+query(rson,mid+1,r,L,R))%p;
}
void init1(int x,int fa){
    dep=dep[fa]+1;
    f=fa;
    size=1;
    for(int i=head;i;i=edge[i].next){
        int y=edge[i].to;
        if(y==fa)continue;
        init1(y,x);
        size+=size[y];
        if(size[y]>size[hs])hs=y;
    }
    return;
}
void init2(int x,int t){
    top=t;
    dfn=++tot;
    if(hs)init2(hs,t);
    for(int i=head;i;i=edge[i].next){
        int y=edge[i].to;
        if(y!=hs&&y!=f){
            init2(y,y);
        }
    }
    return;
}
void add(int x,int y,long long v){
    v%=p;
    while(top!=top[y]){
        if(dep[top]<dep[top[y]])swap(x,y);
        modify(1,1,n,dfn[top],dfn,v);
        x=f[top];
    }
    if(dep<dep[y])swap(x,y);
    modify(1,1,n,dfn[y],dfn,v);
    return;
}
long long get(int x,int y){
    long long ans=0;
    while(top!=top[y]){
        if(dep[top]<dep[top[y]])swap(x,y);
        ans=(ans+query(1,1,n,dfn[top],dfn))%p;
        x=f[top];
    }
    if(dep<dep[y])swap(x,y);
    ans=(ans+query(1,1,n,dfn[y],dfn))%p;
    return ans;
}
int main(){
    ios::sync_with_stdio(false);
    std::cin.tie(0);std::cout.tie(0);
    cin>>n>>m>>root>>p;
    for(int i=1;i<=n;i++)cin>>b[i];
    for(int i=1;i<n;i++){
        int u,v;cin>>u>>v;
        add(u,v);add(v,u);
    }
    init1(root,0);
    init2(root,root);
    for(int i=1;i<=n;i++){
        modify(1,1,n,dfn[i],dfn[i],b[i]);
    }
    while(m--){
        int op,x,y;
        long long v;
        cin>>op;
        if(op==1){
            cin>>x>>y>>v;
            add(x,y,v);
        }else if(op==2){
            cin>>x>>y;
            cout<<get(x,y)<<endl;
        }else if(op==3){
            cin>>x>>v;
            modify(1,1,n,dfn,dfn+size-1,v);
        }else if(op==4){
            cin>>x;
            cout<<query(1,1,n,dfn,dfn+size-1)<<endl;
        }
    }
    return 0;
}
学习笔记
Theme Jasmine by Kent Liao