【模板】树链剖分

题目链接:https://www.luogu.org/problemnew/show/P3384

例题:
【ZJOI2008】树的统计:https://www.luogu.org/problemnew/show/P2590
【HAOI2015】树上操作:https://www.luogu.org/problemnew/show/P3178
【USACO11DEC】Grass Planting:https://www.luogu.org/problemnew/show/P3038

树链剖分是解决树上问题的利器,它利用了线段树和$dfs$序的性质,使其能在$log$级复杂度内解决部分树上修改、查询操作

我们针对这样一棵树来学习树链剖分:

我们首先来思考比较容易的子树修改与查询问题

显然地,如果我们从根开始$dfs$,得到这些节点编号的序列,即$dfs$ 序,节点$x$ 位置记为$id[x]$,同时统计包括该点在内的子树大小记为$siz[x]$,那么这个节点及其子树在$dfs序$所在区间为$[id[x],id[x]+siz[x]-1]$

于是我们可以得到下面这段代码

dfs(x,f)
    id[x]=++cnt
    siz[x]=1
    for(i=x的相邻节点)
        if i≠f
            dfs(i,x)
            siz[x]+=siz[i]
end

cchg(x)
    l←id[x],r←id[x]+siz[x]-1
    chg(l,r)
end

asks(x)
    l←id[x],r←id[x]+siz[x]-1
    return ask(l,r)
end

接下来才是重头戏,我们不仅要解决子树的问题,还要解决链上的问题

于是就有了轻重链剖分

所谓重链,就是连续重边连成的链;所谓重边,就是某子树根节点与其重儿子连成的边;所谓重儿子,就是$siz$ 最大的那一个儿子。

↑这是经过剖分的树,虚线为轻边,实线为重边

我们可以得出寻找重儿子的递归代码:

dfs(x,f)
    siz[x]=1
    hwy=0;
    for(i=x的相邻节点)
        if i≠f
            dfs(i,x)
            siz[x]+=siz[i]
            if hwy<siz[i]
                sn[x]=i
                hwy=siz[i]
end

但是仅仅找出重儿子还不够,我们还要将其利用起来

于是我们需要记录重链

于是我们需要记录链头并保证重链在$dfs$序中连续

于是就有了下面这样的两段dfs

void dfs1(int x,int f,int deep){
    tp[x]=deep;
    siz[x]=1;
    fa[x]=f;
    int hwy=0;
    for(int i=h[x];i;i=a[i].li){
        if(a[i].nx!=f){
            dfs1(a[i].nx,x,deep+1);
            siz[x]+=siz[a[i].nx];
            if(siz[a[i].nx]>hwy){
                sn[x]=a[i].nx;
                hwy=siz[a[i].nx];
            }
        }
    }return;
}

void dfs2(int x,int tpx){
    id[x]=++cnt;
    top[x]=tpx;
    if(!sn[x]) return;
    dfs2(sn[x],tpx);
    for(int i=h[x];i;i=a[i].li){
        if(a[i].nx!=fa[x]&&a[i].nx!=sn[x]){
            dfs2(a[i].nx,a[i].nx);
        }
    }return;
}

dfs1(root,0,1);
dfs2(root,root);

$tp[x]$记录深度,$top[x]$记录链顶

这样对于同一条重链上的修改和查询就迎刃而解了

接下来我们要解决最后一个问题,如果两个节点不在同一条重链上怎么办?

像求$LCA$一样向上跳!

下面给出类似于$LCA$ 求解的跳跃代码

void jump(int x,int y){
    while(top[x]!=top[y]){
        if(tp[top[x]]<tp[top[y]]) swap(x,y);
        x所在重链链上操作
        x=fa[top[x]];
    }if(tp[x]>tp[y]) swap(x,y);
    x,y同一条重链链上操作
    return;
}

至于$dfs$序,我们用线段树维护就好

下面是完整的模板↓

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;

const int MAXN=1<<17;

int n,m,cnt,root,MOD,x,y,np,z,p;
int tp[MAXN],fa[MAXN],siz[MAXN],sn[MAXN],val[MAXN];
int h[MAXN],id[MAXN],ren[MAXN],top[MAXN];
int tree[MAXN<<1],pls[MAXN<<1];
struct rpg{
    int li,nx;
}a[MAXN<<1];

void add(int ls,int nx){
    a[++np]=(rpg){h[ls],nx};
    h[ls]=np;
}

void po(int k,int l,int r){
    if(l==r||pls[k]==0) return;
    int i=k<<1,mid=l+r>>1;
    pls[i]=(pls[i]+pls[k])%MOD;
    pls[i|1]=(pls[i|1]+pls[k])%MOD;
    tree[i]=(tree[i]+pls[k]*(mid-l+1))%MOD;
    tree[i|1]=(tree[i|1]+pls[k]*(r-mid))%MOD;
    pls[k]=0;
}

void cadd(int k,int l,int r,int le,int ri,int x){
    po(k,l,r);
    if(le<=l&&r<=ri){
        pls[k]=x;
        tree[k]=(tree[k]+x*(r-l+1))%MOD;
        return;
    }int i=k<<1,mid=l+r>>1;
    if(le<=mid) cadd(i,l,mid,le,ri,x);
    if(mid<ri) cadd(i|1,mid+1,r,le,ri,x);
    tree[k]=(tree[i]+tree[i|1])%MOD;
}

int ask(int k,int l,int r,int le,int ri){
    po(k,l,r);
    if(le<=l&&r<=ri) return tree[k]%MOD;
    int i=k<<1,mid=l+r>>1;
    int sum=0;
    if(le<=mid) sum=(sum+ask(i,l,mid,le,ri))%MOD;
    if(mid<ri) sum=(sum+ask(i|1,mid+1,r,le,ri))%MOD;
    return sum;
}

void cadd1(int x,int y,int z){
    while(top[x]!=top[y]){
        if(tp[top[x]]<tp[top[y]]) swap(x,y);
        cadd(1,1,n,id[top[x]],id[x],z);
        x=fa[top[x]];
    }if(tp[x]>tp[y]) swap(x,y);
    cadd(1,1,n,id[x],id[y],z);
    return;
}

int ask1(int x,int y){
    long long sum=0;
    while(top[x]!=top[y]){
        if(tp[top[x]]<tp[top[y]]) swap(x,y);
        sum+=ask(1,1,n,id[top[x]],id[x]);
        if(sum>=MOD) sum%=MOD;
        x=fa[top[x]];
    }if(tp[x]>tp[y]) swap(x,y);
    sum+=ask(1,1,n,id[x],id[y]);
    return sum%MOD;
}

void init(){
    scanf("%d%d%d%d",&n,&m,&root,&MOD);
    for(int i=1;i<=n;++i){
        scanf("%d",&val[i]);
    }for(int i=1;i<n;++i){
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
    }return;
}

void dfs1(int x,int f,int deep){
    tp[x]=deep;
    fa[x]=f;
    siz[x]=1;
    int hwy=0;
    for(int i=h[x];i;i=a[i].li){
        if(a[i].nx!=f){
            dfs1(a[i].nx,x,deep+1);
            siz[x]+=siz[a[i].nx];
            if(siz[a[i].nx]>hwy){
                sn[x]=a[i].nx;
                hwy=siz[a[i].nx];
            }
        }
    }return;
}

void dfs2(int x,int topx){
    id[x]=++cnt;
    ren[cnt]=val[x];
    top[x]=topx;
    if(!sn[x]) return;
    dfs2(sn[x],topx);
    for(int i=h[x];i;i=a[i].li){
        if(a[i].nx!=fa[x]&&a[i].nx!=sn[x]){
            dfs2(a[i].nx,a[i].nx);
        }
    }return;
}

void build(int k,int l,int r){
    if(l==r){
        tree[k]=ren[l];
        return;
    }int i=k<<1,mid=l+r>>1;
    build(i,l,mid);
    build(i|1,mid+1,r);
    tree[k]=(tree[i]+tree[i|1])%MOD;
}

void solve(){
    while(m--){
        scanf("%d",&p);
        if(p==1) scanf("%d%d%d",&x,&y,&z),cadd1(x,y,z%MOD);
        else if(p==2) scanf("%d%d",&x,&y),printf("%d\n",ask1(x,y));
        else if(p==3) scanf("%d%d",&x,&z),cadd(1,1,n,id[x],id[x]+siz[x]-1,z%MOD);
        else scanf("%d",&x),printf("%d\n",ask(1,1,n,id[x],id[x]+siz[x]-1));
    }return;
}

int main(){
    init();
    dfs1(root,0,1);
    dfs2(root,root);
    build(1,1,n);
    solve();
    return 0;
}