hdu 4747 Mex

时间:2019-08-19
本文章向大家介绍hdu 4747 Mex,主要包括hdu 4747 Mex使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

Mex是一组整数的函数,它普遍用于公正的博弈定理。对于非负整数集S,mex(S)被定义为S中未出现的最小非负整数。现在我们的问题是关于序列上的mex函数。

考虑一系列非负整数{ai},我们将mex(L,R)定义为最小的非负整数,它不出现在从aL到aR的连续子序列中,包括端点。现在我们想要计算所有1 <= L <= R <= n的mex(L,R)之和。
 

输入
输入包含最多20个测试用例。
对于每个测试用例,第一行包含一个整数n,表示序列的长度。
下一行包含由空格分隔的n个非整数,表示序列。
(1 <= n <= 200000,0 <= ai <= 10 ^ 9)
输入以n = 0结束。
 

产量
对于每个测试用例,输出一行包含表示答案的整数。
 

样本输入
3
0 1 3
5
1 0 2 0 1
0
 

样本输出

24
暗示

对于第一个测试用例:
mex(1,1)= 1,mex(1,2)= 2,mex(1,3)= 2,mex(2,2)= 0,mex(2,3)= 0,mex(3,3) )= 0。
 1 + 2 + 2 + 0 +0 +0 = 5。

思路:

看到题目和区间有关,于是定左端点,枚举长度手算样例我们发现

左端点一定的时候,本序列递增(不过没什么卵用)

当左端点向左移动一位,从该位到下一次出现该位上的数的位置(若下一次不出现则为n+1)之间的所有mex值为该位上的数;

于是我们就有了线段树的思路:

先暴力算出以1号为左边界的mex值,然后用线段树维护这个数组即可

(话说是真的难调....花了我2个小时,难受难受)

#include <cstdio>
#include <map>
#include <cstring>

#define R register
#define MAXN 200005
#define ll long long
#define inf 2147483647

int n;
int flag=0;
int a[MAXN];
int f[MAXN];
int nxt[MAXN];
struct node
{
    int maxx,tag,l,r;
    ll sum;
}t[MAXN<<2];


inline int read();
inline void turn();
inline void input();
inline int ls(int x);
inline int rs(int x);
inline void down(int k);
inline int max(int x,int y);
inline int find(int x,int va);
inline void build(int x,int l,int r);
inline ll get_sum(int x,int LL,int RR);
inline void change(int x,int LL,int RR,int va);

int main()
{    
    while(1)
    {
        R ll ans=0;
        input();
        if(flag) return 0;
        turn();
        build(1,1,n);
        for(R int i=1;i<=n;i++)
        {
            ans+=t[1].sum;
            if(t[1].maxx>a[i])
            {
                int w=find(1,a[i]),end=nxt[i]-1;
                if(w<=end)
                change(1,w,end,a[i]);
            }
            change(1,i,i,0);
        }
        printf("%lld\n",ans);
    }
    return 0;
}

inline int read()
{
    #define C getchar()
    R int x=0,f=1;char a=C;
    for(;a>'9'||a<'0';a=C)if(a=='-') f=-1;
    for(;a>='0'&&a<='9';a=C) x=(x<<1)+(x<<3)+(a^48);
    x*=f;
    return x;
}

inline void input()
{
    memset(t,0,sizeof(t));
    memset(a,0,sizeof(a));
    memset(f,0,sizeof(f));
    memset(nxt,0,sizeof(nxt));
    std::map<int,int> mp;
    n=read();if(n==0) {flag=1;return;}
    for(R int i=1;i<=n;i++) a[i]=read();
    for(R int i=n;i>=1;i--) 
    {
        if(mp[a[i]]==0) mp[a[i]]=i;
        else nxt[i]=mp[a[i]],mp[a[i]]=i;
    }
    for(R int i=1;i<=n;i++) if(nxt[i]==0) nxt[i]=n+1;
}

inline void turn()
{
    std::map<int,bool> vis;
    R int tot=0;
    for(R int i=1;i<=n;i++)
    {
        vis[a[i]]=1;
        while(vis[tot]) tot++;
        f[i]=tot;
    }
}

inline int ls(int x) {return x<<1;}
inline int rs(int x) {return x<<1|1;}
inline int max(int x,int y) {return x>y?x:y;}

inline void build(int x,int l,int r)
{
    t[x].maxx=-inf;t[x].sum=0;t[x].tag=-1;
    t[x].l=l;t[x].r=r;
    if(l==r)
    {
        t[x].maxx=t[x].sum=f[l];
        return;
    }
    int mid=l+r;mid>>=1;
    build(ls(x),l,mid);
    build(rs(x),mid+1,r);
    t[x].maxx=max(t[ls(x)].maxx,t[rs(x)].maxx);
    t[x].sum=t[ls(x)].sum+t[rs(x)].sum;
}

inline void down(int x)
{
    t[ls(x)].sum=(t[ls(x)].r-t[ls(x)].l+1)*t[x].tag;
    t[rs(x)].sum=(t[rs(x)].r-t[rs(x)].l+1)*t[x].tag;
    t[ls(x)].maxx=t[x].tag;t[rs(x)].maxx=t[x].tag;
    t[ls(x)].tag=t[x].tag;t[rs(x)].tag=t[x].tag;
    t[x].tag=-1;
}

inline void change(int x,int LL,int RR,int va)
{
    #define l t[x].l
    #define r t[x].r
    if(LL<=l&&RR>=r)
    {
        t[x].sum=(r-l+1)*va;
        t[x].maxx=va;
        t[x].tag=va;
        return;
    }
    if(t[x].tag!=-1) down(x);
    int mid=l+r;mid>>=1;
    if(LL<=mid) change(ls(x),LL,RR,va);
    if(RR>mid) change(rs(x),LL,RR,va);
    t[x].maxx=max(t[ls(x)].maxx,t[rs(x)].maxx);
    t[x].sum=t[ls(x)].sum+t[rs(x)].sum;
}

inline int find(int x,int va)
{
    #define l t[x].l
    #define r t[x].r
    if(l==r) return l;
    if(t[x].tag!=-1) down(x);
    int mid=l+r;mid>>=1;
    if(t[ls(x)].maxx>va)
    return find(ls(x),va);
    return find(rs(x),va);
}

inline ll get_sum(int x,int LL,int RR)
{
    #define l t[x].l
    #define r t[x].r
    ll res=0;
    if(LL<=l&&RR>=r)
    {    
        return t[x].sum;
    }
    if(t[x].tag!=-1) down(x);
    int mid=l+r;mid>>=1;
    if(LL<=mid) res+=get_sum(ls(x),LL,RR);
    if(RR>mid) res+=get_sum(rs(x),LL,RR);
    return res;
}

原文地址:https://www.cnblogs.com/000226wrp/p/11379662.html