浅谈线段树中加与乘标记的下放

时间:2022-05-07
本文章向大家介绍浅谈线段树中加与乘标记的下放,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

假设我们一个节点为[val,mul,add],其中val代表该节点的权值,mul为乘法标记,add为加法标记

那么我们有两种表示方式,

  • 第一种:先加再乘

此时该节点为(val+add)*mul

当再遇到一个[_mul,_add]的标记时,

此时节点为[(val+add)*mul+_add]*_mul

把式子展开并重新化为(val+add')*mul'的形式 (也就是提出mul*_mul这一项)得

(val+add+frac{_add}{mul})*mul*_mul

我们发现这里有个除法,会损失很多精度

因此我们换一个思路

  • 第二种:先乘再加

此时该节点为(val*mul)+add

当再遇到一个[_mul,_add]的标记时,

此时节点为[(val*mul)+add]*_mul+_add

把式子展开并重新化为(val*mul')+add'的形式

val*mul*_mul+add*_mul+_add

我们发现这样不需要除法,因此我们选用第二种

其实线段树标记的下放一般都是这个套路

放一下丑陋的代码

// luogu-judger-enable-o2
// luogu-judger-enable-o2
#include<cstdio>
#include<cmath>
#include<algorithm>
#define ls k<<1
#define rs k<<1|1
#define int long long 
using namespace std;
const int MAXN=1e6+10;
inline int read()
{
    char c=getchar();int x=0,f=1;
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}
int N,M,mod;
struct node
{
    int mul,add,sum,l,r,siz;
}T[MAXN];
void update(int k)
{
    T[k].sum=(T[ls].sum%mod+T[rs].sum%mod)%mod;
}
void ps(int x,int f)
{
    T[x].mul=(T[x].mul%mod*T[f].mul%mod)%mod;
    T[x].add=(T[x].add*T[f].mul)%mod;
    T[x].add=(T[x].add+T[f].add)%mod;
    T[x].sum=(T[x].sum%mod*T[f].mul%mod)%mod;
    T[x].sum=(T[x].sum+T[f].add%mod*T[x].siz)%mod;
}
void pushdown(int k)
{
    if(T[k].add==0&&T[k].mul==1) return ;
    ps(ls,k);
    ps(rs,k);
    T[k].add=0;
    T[k].mul=1;
}
void Build(int k,int ll,int rr)
{
    T[k].l=ll;T[k].r=rr;T[k].siz=rr-ll+1;T[k].mul=1;
    if(ll==rr)
    {
        T[k].sum=read()%mod;
        return ;
    }
    int mid=ll+rr>>1;
    Build(ls,ll,mid);
    Build(rs,mid+1,rr);
    update(k);
}
void IntervalMul(int k,int ll,int rr,int val)
{
    if(ll<=T[k].l&&T[k].r<=rr)
    {
        T[k].sum=(T[k].sum*val)%mod;
        T[k].mul=(T[k].mul*val)%mod;
        T[k].add=(T[k].add*val)%mod;
        return ;
    }
    pushdown(k);
    int mid=T[k].l+T[k].r>>1;
    if(ll<=mid) IntervalMul(ls,ll,rr,val);
    if(rr>mid)  IntervalMul(rs,ll,rr,val);
    update(k);
}
void IntervalAdd(int k,int ll,int rr,int val)
{
    if(ll<=T[k].l&&T[k].r<=rr)
    {
        T[k].sum=(T[k].sum+T[k].siz*val)%mod;
        T[k].add=(T[k].add+val)%mod;
        return ;
    }
    pushdown(k);
    int mid=T[k].l+T[k].r>>1;
    if(ll<=mid) IntervalAdd(ls,ll,rr,val);
    if(rr>mid)  IntervalAdd(rs,ll,rr,val);
    update(k);
}
int IntervalSum(int k,int ll,int rr)
{
    int ans=0;
    if(ll<=T[k].l&&T[k].r<=rr)
    {
        ans=(ans+T[k].sum)%mod;
        return ans;
    }
    pushdown(k);
    int mid=T[k].l+T[k].r>>1;
    if(ll<=mid) ans=(ans+IntervalSum(ls,ll,rr))%mod;
    if(rr>mid)  ans=(ans+IntervalSum(rs,ll,rr))%mod;
    return ans%mod;
}
main()
{
    #ifdef WIN32
    freopen("a.in","r",stdin);
    #endif
    N=read();M=read();mod=read();
    Build(1,1,N);
    while(M--)
    {
        int opt=read();
        if(opt==1)
        {
            int l=read(),r=read(),val=read()%mod;
            IntervalMul(1,l,r,val);
        }
        else if(opt==2)
        {
            int l=read(),r=read(),val=read()%mod;
            IntervalAdd(1,l,r,val);
        }
        else if(opt==3)
        {
            int l=read(),r=read();
            printf("%lldn",IntervalSum(1,l,r)%mod);
        }
    }
    return 0;
}