奇偶路径(换根dp)

题目大意:对于正整数 𝑎, 𝑏 定义函数 𝑓(𝑎, 𝑏) 表示将 𝑎 和 𝑏 用二进制表示后二者不同位的个数。现给出一棵 𝑛 个点的树,点有点权 𝑎𝑖 ,点 𝑖 与点 𝑗 的距离 𝑑𝑖𝑠(𝑖,𝑗) 定义为两点间最短路经过的边数。询问共有多少点对满足 𝑓(𝑎𝑖, 𝑎𝑗) 与 𝑑𝑖𝑠(𝑖,𝑗) 的奇偶性不同

首先鸣谢Tethys大佬

大佬说一眼看出是个换根dp

无论从图论还是dp的角度来说,这都是一道不可多得的好题

考虑对于一个节点x,一个能与他组成点对并造成贡献的点必定属于以下两种:

  • 距离x奇数条边,且\(a_i\)\(a_x\)有偶数位不同
  • 距离x偶数条便,且\(a_i\)\(a_x\)有奇数位不同

然后考虑x与其儿子y,y如何继承来自x的价值?

显然:x与y只差着一条边,如果一个点需要经历奇数条边才能到达x,那么此点需要经历偶数条边才能到达y,反之亦然

如果发现\(a_x\)\(a_y\)有奇数位不同,可以转化为:原来与x这个点的权值有偶数位不同的点,现在与y这个点有奇数位不同

如果发现\(a_x\)\(a_y\)有偶数位不同,可以转化为:原来与x这个点的权值有奇数位不同的点,现在与y这个点有偶数位不同

证明:

\(a_x\)\(a_y\)有奇数位不同的情况下,设另一个点为t,且\(a_t\)\(a_x\)有偶数位不同,偶数位?我们设为0

显然此时\(a_y\)\(a_t\)有奇数位不同,反之依然

这种类比化一的思想太niub了

特判除了1的情况,别的情况就都可以继承啦!

最后除以2的原因是因为我们的答案计算了两遍,设想在一个点\(l\)的时候,肯定计算了另一个点\(r\)所带来的贡献,在\(r\)的时候,又再次计算了\(l\)的贡献,而本质上\(f(a,b)=f(b,a)\)所以理所应当的除以2

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include <bits/stdc++.h>
#define int long long
using namespace std;

const int N = 1e5 + 66;

inline int read()
{
int s(0), w(1);
char ch = getchar();
while (ch < '0' || ch > '9') {if (ch == '-') w = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
return s * w;
}

inline void put(int x)
{
if (! x) putchar('0');
if (x < 0) putchar('-'), x = -x;
int num(0); char c[66];
while (x) c[++ num] = x % 10 + 48, x /= 10;
while (num) putchar(c[num --]);
return (void)(putchar('\n'));
}

int ver[N << 1], nex[N << 1], head[N << 1], cnt;

inline void add_edge(int x, int y)
{
ver[++ cnt] = y;
nex[cnt] = head[x];
return (void)(head[x] = cnt);
}

int n, res;
int a[N], f[N], h[N], dep[N], fa[N];

inline int fuck(int x, int y)
{
int ret = 0;
while (x || y)
{
if ((x & 1) != (y & 1)) ++ ret;
x >>= 1, y >>= 1;
}
return ret;
}

inline void build(int x, int yhm_fa)
{
if (x != 1)
{
if ((dep[x] & 1) != (fuck(a[1], a[x]) & 1))
++ h[1];
else ++ f[1];
}
int i, y;
for (i = head[x]; i; i = nex[i])
{
y = ver[i];
if (y == yhm_fa) continue;
fa[y] = x;
dep[y] = dep[x] + 1;
build(y, x);
}
return;
}

inline void dfs(int x)
{
int i, y;
if (x == 1)
{
res += h[1];
for (i = head[x]; i; i = nex[i])
{
y = ver[i];
dfs(y);
}
}
else
{
if (fuck(a[fa[x]], a[x]) & 1)
{
h[x] = h[fa[x]];
f[x] = f[fa[x]];
}
else
{
h[x] = f[fa[x]];
f[x] = h[fa[x]];
}
res += h[x];
for (i = head[x]; i; i = nex[i])
{
y = ver[i];
if (y == fa[x]) continue;
dfs(y);
}
}
return;
}

signed main()
{
int i, x, y;

n = read();
for (i = 1; i < n; ++ i)
{
x = read(), y = read();
add_edge(x, y), add_edge(y, x);
}
for (i = 1; i <= n; ++ i) a[i] = read();

f[1] = 1;
build(1, 0), dfs(1);
put(res / 2);
return 0;
}