算法学习:线段树

时间:2019-07-04
本文章向大家介绍算法学习:线段树,主要包括算法学习:线段树使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

线段树,个人理解,生成一棵二叉树,树上的节点表示区间的答案,因为二叉树的性质天然就将树分成两半,所以可以用每个节点存左半边右半边,然后这样子就可以保证效率。

具体讲解是看这位大大的博客,图解和语言都很详细。

https://www.cnblogs.com/TheRoadToTheGold/p/6254255.html

这个里面树是靠结构体存储了一个真正意义上节点为包含左右区间范围及其答案的节点

然后我的代码是参考了洛谷的一个题解的代码,所以是用询问函数上的范围取代了,但是这两者区别并不是很大,因为p这个节点的数字,通过二进制来看,天然的就可以求取出他所代表的区间。

附代码,对树的解释都在注释上:

  1 #include<cstdio>
  2 #define ll long long
  3 #define MAXN 100010
  4 int a[MAXN];
  5 int ans[MAXN];
  6 int t[MAXN];
  7 int tag[MAXN];
  8 inline int ls(int p){return p<<1;}//找左节点 
  9 inline int rs(int p){return p<<1|1;}//找右节点 
 10 void push_up(int p)//更新操作 
 11 {
 12     ans[p]=ans[ls(p)]+ans[rs(p)];
 13     return;
 14 }
 15 void build(ll p,ll l,ll r)//建树 
 16 {
 17     if(l==r)//如果是底层叶节点 
 18     {        //l==r区间,也就是这个节点本身 
 19         ans[p]=a[l];    //返回叶节点 
 20         return;            //那么就是当前数组保存此节点 
 21     }
 22     ll mid=(l+r)>>1;
 23     build(ls(p),l,mid);//使p的左孩子包含(l---mid)的值 
 24     build(rs(p),mid+1,r);
 25     push_up(p); 
 26     //左右结点的值已经通过递归得到
 27     //可以根据左右确定(更新)自己的值 
 28     return;
 29 }
 30 void f(int p,int l,int r,int k)
 31 {
 32     tag[p]+=k;
 33     //给当前节点加上之前节点的懒标记 
 34     //这样子懒标记不会下传
 35     //但是之后还可以用,不会影响其他的查询 
 36     ans[p]+=(r+1-l)*k;
 37     //更新树上的值
 38     //这样就不会向下扩散 
 39     //保证效率 
 40 }
 41 void push_down(int p,int l,int r)
 42 {
 43     ll mid=(l+r)>>1;
 44     //向下更新 
 45     f(ls(p),l,mid,tag[p]);
 46     f(rs(p),mid+1,r,tag[p]);
 47     tag[p]=0;
 48     //因为这个点的懒标记已经下传
 49     //所以这个点对下面来说
 50     //已经没有了需要更新的值 
 51     //懒标记清0 
 52 } 
 53 //我个人最难理解的一步 
 54 //对区间的更新 
 55 //nr,nl需要查询的期间
 56 //l,r当前节点 
 57 void update(ll nl,ll nr,ll l,ll r,ll p,ll k)
 58 {
 59     if(nl<=l&&r<=nr)
 60     //若当前区间全部在所查询区间
 61     //直接进行更新 
 62         {
 63             ans[p]+=(r+1-l)*k;
 64             tag[p]+=k;
 65             return ;     
 66         }    
 67     push_down(p,l,r); 
 68     //因为需要保证下面两个节点的正确性 
 69     //所以要向下进行更新 
 70     ll mid=(r+l)>>1;
 71     if(nl<=mid)    update(nl,nr,l,mid,ls(p),k);
 72     if(nr>mid)    update(nl,nr,mid+1,r,rs(p),k);
 73     //根据区间对子节点进行更新 
 74     //先将子节点更新完成 
 75     push_up(p);
 76     return;
 77 } 
 78 ll query(ll nl,ll nr,ll l,ll r,ll p)
 79 {
 80     ll res=0;
 81     if(nl<=l&&r<=nr)
 82         {
 83             return ans[p];     
 84         }    
 85     ll mid=(r+l)>>1;
 86     push_down(p,l,r);
 87     if(nl<=mid)    res+=query(nl,nr,l,mid,ls(p));
 88     if(nr>mid)    res+=query(nl,nr,mid+1,r,rs(p));
 89     return    res;
 90 }
 91 int main()
 92 {
 93     int n,m;
 94     scanf("%d%d",&n,&m);
 95     for(int i=1;i<=n;i++)
 96         {
 97             scanf("%d",&a[i]);
 98         }
 99     build(1,1,n);
100     while(m--)
101     {
102         int p;
103         scanf("%d",&p);
104         switch(p)
105         {
106             case 1:
107                 {
108                     int l,r,k;
109                     scanf("%d%d%d",&l,&r,&k);
110                     update(l,r,1,n,1,k);
111                     //可以理解为对第一个节点进行更新
112                     //第一个节点包括的区间就是1~n 
113                     break;                //修改操作 
114                 }
115             case 2:
116                 {
117                     int l,r;
118                     scanf("%d%d",&l,&r);
119                     printf("%lld\n",query(l,r,1,n,1));//输出值的操作 
120                     break;//1,n代表的是第一个节点,代表的是具有从1~n的值的数字 
121                 }
122         }
123     }
124     return 0;
125 }
注释版
  1 #include<cstdio>
  2 #define ll long long
  3 #define MAXN 100010
  4 int a[MAXN];
  5 int ans[MAXN];
  6 int t[MAXN];
  7 int tag[MAXN];
  8 inline int ls(int p){return p<<1;}
  9 inline int rs(int p){return p<<1|1;} 
 10 void push_up(int p)
 11 {
 12     ans[p]=ans[ls(p)]+ans[rs(p)];
 13     return;
 14 }
 15 void build(ll p,ll l,ll r)
 16 {
 17     if(l==r)    
 18     {         
 19         ans[p]=a[l];    
 20         return;             
 21     }
 22     ll mid=(l+r)>>1;
 23     build(ls(p),l,mid); 
 24     build(rs(p),mid+1,r);
 25     push_up(p); 
 26     return;
 27 }
 28 void f(int p,int l,int r,int k)
 29 {
 30     tag[p]+=k;
 31     ans[p]+=(r+1-l)*k;
 32 }
 33 void push_down(int p,int l,int r)
 34 {
 35     ll mid=(l+r)>>1;
 36     f(ls(p),l,mid,tag[p]);
 37     f(rs(p),mid+1,r,tag[p]);
 38     tag[p]=0;
 39 } 
 40 void update(ll nl,ll nr,ll l,ll r,ll p,ll k)
 41 {
 42     push_down(p,l,r); 
 43     if(nl<=l&&r<=nr) 
 44         {
 45             ans[p]+=(r+1-l)*k;
 46             tag[p]+=k;
 47             return ;     
 48         }    
 49     ll mid=(r+l)>>1;
 50     if(nl<=mid)    update(nl,nr,l,mid,ls(p),k);
 51     if(nr>mid)    update(nl,nr,mid+1,r,rs(p),k);
 52     push_up(p);
 53     return;
 54 } 
 55 ll query(ll nl,ll nr,ll l,ll r,ll p)
 56 {
 57     ll res=0;
 58     push_down(p,l,r); 
 59     if(nl<=l&&r<=nr)
 60         {
 61             return ans[p];     
 62         }    
 63     ll mid=(r+l)>>1;
 64     if(nl<=mid)    res+=query(nl,nr,l,mid,ls(p));
 65     if(nr>mid)    res+=query(nl,nr,mid+1,r,rs(p));
 66     return    res;
 67 }
 68 int main()
 69 {
 70     int n,m;
 71     scanf("%d%d",&n,&m);
 72     for(int i=1;i<=n;i++)
 73         {
 74             scanf("%d",&a[i]);
 75         }
 76     build(1,1,n);
 77     while(m--)
 78     {
 79         int p;
 80         scanf("%d",&p);
 81         switch(p)
 82         {
 83             case 1:
 84                 {
 85                     int l,r,k;
 86                     scanf("%d%d%d",&l,&r,&k);
 87                     update(l,r,1,n,1,k);
 88                     break;            
 89                 }
 90             case 2:
 91                 {
 92                     int l,r;
 93                     scanf("%d%d",&l,&r);
 94                     printf("%lld\n",query(l,r,1,n,1));
 95                     break; 
 96                 }
 97         }
 98     }
 99     return 0;
100 }
101 
102  
103 
104 题目练习:
105 
106 [AHOI 2009] 维护序列
107 
108 思路:实际上还是道板子题,只不过加入了乘法的应用,对加法的操作还是和之前一样,但是乘法运算时,还要对加法运算的懒标记进行更新,注意push_down函数的位置,以及各种小问题,(因为各种小问题卡了好久)
109 
110 #include<cstdio>
111 #include<iostream>
112 #define ll long long
113 #define MAXN 100010
114 using namespace std;
115 ll a[MAXN];
116 ll ans[MAXN<<2];
117 ll res;
118 ll tag_m[MAXN<<2],tag_a[MAXN<<2];
119 ll mod,n,m;
120 inline int ls(int p){return p<<1;}
121 inline int rs(int p){return p<<1|1;} 
122 void push_up(int p)
123 {
124     ans[p]=(ans[ls(p)]+ans[rs(p)])%mod;
125     return;
126 }
127 void build(ll p,ll l,ll r)
128 {
129     if(l==r)    
130     {         
131         ans[p]=a[l];    
132         return;             
133     }
134     ll mid=(l+r)>>1;
135     build(ls(p),l,mid); 
136     build(rs(p),mid+1,r);
137     push_up(p); 
138     return;
139 }
140 void push_down(int p,int l,int r,int op)
141 {
142     if(tag_m[p]==1&&tag_a[p]==0)
143         return;
144     ll lp=ls(p),rp=rs(p);
145     if(l!=r)
146         {
147             tag_m[rp]=tag_m[rp]*tag_m[p]%mod;
148             tag_m[lp]=tag_m[lp]*tag_m[p]%mod;
149             tag_a[lp]=((tag_a[lp]*tag_m[p])%mod+tag_a[p])%mod;
150             tag_a[rp]=((tag_a[rp]*tag_m[p])%mod+tag_a[p])%mod;
151         }    
152     ans[p]=(ans[p]*tag_m[p]%mod+tag_a[p]*(r-l+1)%mod)%mod;
153     tag_m[p]=1;tag_a[p]=0;
154 } 
155 void update(ll nl,ll nr,ll l,ll r,ll p,ll k,int op)
156 {    
157     push_down(p,l,r,op);
158     if(nl<=l&&r<=nr) 
159         {
160             switch(op)
161                 {
162                     case 1:tag_m[p]=(tag_m[p]*k)%mod,tag_a[p]=(tag_a[p]*k)%mod;break;
163                     case 2:tag_a[p]=(tag_a[p]+k)%mod;break;
164                 }    
165             return ;     
166         }    
167     ll mid=(r+l)>>1;
168     if(nl<=mid)    update(nl,nr,l,mid,ls(p),k,op);
169     if(nr>mid)    update(nl,nr,mid+1,r,rs(p),k,op);
170     push_down(ls(p),l,mid,op);
171     push_down(rs(p),mid+1,r,op);
172     push_up(p);
173     return;
174 } 
175 ll query(ll nl,ll nr,ll l,ll r,ll p,int op)
176 {
177     ll res=0;
178     push_down(p,l,r,op);
179     if(nl<=l&&r<=nr)
180         {
181             return ans[p];     
182         }    
183     ll mid=(r+l)>>1;    
184     if(nl<=mid)    res+=query(nl,nr,l,mid,ls(p),op);
185     if(nr>mid)    res+=query(nl,nr,mid+1,r,rs(p),op);
186     return    res;
187 }
188 int main()
189 {
190     scanf("%lld%lld",&n,&mod);
191     for(int i=1;i<=n;i++)
192         {
193             scanf("%lld",&a[i]);
194         }
195     for(int i=1;i<=2*n;i++)
196         {
197             tag_m[i]=1;
198         }
199     build(1,1,n);
200     scanf("%lld",&m);
201     for(int i=1;i<=m;i++)    
202         {
203             int op,l,r,k;
204             scanf("%d",&op);
205             switch(op)
206                 {
207                 
208                     case 1:
209                     case 2:scanf("%d%d%d",&l,&r,&k);
210                            update(l,r,1,n,1,k,op);
211                            break;
212                     case 3:scanf("%d%d",&l,&r);
213                           res=query(l,r,1,n,1,op)%mod;
214                            printf("%lld\n",res);
215                           break;
216                 }
217         }
218 }
无注释版

原文地址:https://www.cnblogs.com/rentu/p/11134424.html