【JZOJ3234】阴阳

时间:2019-07-11
本文章向大家介绍【JZOJ3234】阴阳,主要包括【JZOJ3234】阴阳使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

题目大意

给出一棵\(n\)个点的树,每条边的权值是1或0,一条路径合法的条件是:路径上存在一个休息点(不能是起点也不能是终点),使得起点到该点路径上0和1的个数相等,该点到终点的路径上0和1的个数也相等。求合法路径条数。

分析

求满足条件的树上路径条数显然是点分治。

考虑分治中心\(x\),对于两条路径\(x->u\)\(x->v\),路径\(u->v\)合法有三种情况:

  • 休息点在\(x->u\)
  • 休息点在\(x->v\)
  • 休息点在\(u\)

将0视作-1,将1视作1,一条路径0和1的个数相等等价于路径权值和为0,那么我们只需要预处理每个点\(u\),路径\(x->u\)是否有休息点,然后用两个桶统计一下答案就行了。为了便于计算答案,我们一棵一棵子树做,就不用容斥了。

Code

#include <cstdio>
#include <cstring>

typedef long long ll;
const int N = 300007;
int max(int a, int b) { return a > b ? a : b; }

ll ans;
int n;
int sum, p, tot, st[N], to[N << 1], nx[N << 1], len[N << 1], siz[N], maxsiz[N], del[N];
void add(int u, int v, int w) { to[++tot] = v, nx[tot] = st[u], len[tot] = (w == 1 ? -1 : 1), st[u] = tot; }

void getp(int u, int from)
{
    siz[u] = 1, maxsiz[u] = 0;
    for (int i = st[u]; i; i = nx[i]) if (to[i] != from && !del[to[i]]) getp(to[i], u), siz[u] += siz[to[i]], maxsiz[u] = max(maxsiz[u], siz[to[i]]);
    maxsiz[u] = max(maxsiz[u], sum - siz[u]);
    if (maxsiz[u] < maxsiz[p]) p = u;
}
int cnt, arr[N], dis[N], ok[N], buc[N * 4], b[N * 4], b0[N * 4];
void getdis(int u, int from)
{
    if (b0[dis[u] + N]) ok[u] = 1;
    else ok[u] = 0;
    arr[++cnt] = u, b0[dis[u] + N]++;
    for (int i = st[u]; i; i = nx[i]) if (to[i] != from && !del[to[i]]) dis[to[i]] = dis[u] + len[i], getdis(to[i], u);
    b0[dis[u] + N]--;
}
void solve(int u)
{
    del[u] = 1, dis[u] = 0;
    for (int i = st[u]; i; i = nx[i])
        if (!del[to[i]])
        {
            ll ret = 0;
            cnt = 0, dis[to[i]] = len[i], getdis(to[i], u);
            for (int j = 1; j <= cnt; j++)
            {
                if (ok[arr[j]])
                {
                    ret += b[N - dis[arr[j]]] + buc[N - dis[arr[j]]];
                    if (!dis[arr[j]]) ret++;
                }
                else
                {
                    ret += buc[N - dis[arr[j]]];
                    if (!dis[arr[j]]) ret += b[N];
                }
            }
            for (int j = 1; j <= cnt; j++)
            {
                if (ok[arr[j]]) buc[dis[arr[j]] + N]++;
                else b[dis[arr[j]] + N]++;
            }
            ans += ret;
        }
    dis[u] = 0, getdis(u, 0);
    for (int i = 1; i <= cnt; i++) b[dis[arr[i]] + N] = 0, buc[dis[arr[i]] + N] = 0, ok[arr[i]] = 0;
    for (int i = st[u]; i; i = nx[i]) if (!del[to[i]]) sum = siz[to[i]], p = 0, getp(to[i], 0), solve(p);
}

int main()
{
    scanf("%d", &n);
    for (int i = 1, u, v, w; i < n; i++) scanf("%d%d%d", &u, &v, &w), add(u, v, w), add(v, u, w);
    sum = n, maxsiz[0] = N, getp(1, 0), solve(p);
    printf("%lld\n", ans);
    return 0;
}

$flag 上一页 下一页