题目大意
有一棵有$n$个节点的树,每个节点上有$0/1$枚棋子,每次可以选择两个棋子并移动到它们的路径上的相邻节点(满足路径长度至少为$2$),求把所有棋子移到同一个节点的最小花费(无解输出$-1$)。
$n \leq 2 \times 10 ^ 3$
分析
枚举最后汇聚到的点$\texttt{root}$,并以$\texttt{root}$为根建树
我们可以发现,如果存在一个合法的方案,则必然是每次选择不存在祖先关系的两枚棋子,同时向着他们的$\texttt{lca}$处跳一格,重复若干步,直到所有棋子都在$\texttt{root}$
由此我们联想到一个经典模型:有$\texttt{sum}$个节点被分成了若干个集合,每次要找到不在同一集合的两个节点匹配并抵消。
设$\texttt{max}$为最大的集合的大小,则当$sum - max \geq max$时,刚好可以消去$\lfloor \frac{sum}{2} \rfloor$对节点
否则剩下$2 \times max - sum$个来自最大集合的节点,消去了$sum - max$对
现在我们回到原问题,考虑在$u$处做这个操作,设$f_u$表示在$u$的子树里最多消去了多少对。
我们把所有$u$的子树内的有棋子的节点$v$拆成$dis(v, \; u)$个节点,则我们有如下转移(仍然设$\texttt{sum}$为总结点个数,$\texttt{max}$为最大的集合的大小)
$sum - max \geq max$ ,此时$f_u=\lfloor \frac{sum}{2} \rfloor$
$sum - max < max$,此时需要最大子树$v$内的节点来抵消,此时$f_u=sum-max+ \min (f_v, \lfloor \frac{2 \times max - sum}{2} \rfloor )$
以$\texttt{root}$为根的情况合法当且仅当$f_{root} = \frac{\Sigma_u dis(u, root)}{2}$,同时这也是以$\texttt{root}$为根的答案
对$\texttt{root}=1 - n$重复这个$\texttt{dp}$过程,时间复杂度$\texttt{O(}n^2\texttt{)}$
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
| #include <bits/stdc++.h>
#define R register #define ll long long #define sum(a, b, mod) (((a) + (b)) % mod)
const int MaxN = 1e4 + 10; const int inf = 0x3f3f3f3f;
struct edge { int next, to; };
char s[MaxN]; edge e[MaxN]; int n, ans, cnt; int a[MaxN], head[MaxN], dis[MaxN], size[MaxN], f[MaxN];
void init() { for (int i = 1; i <= n; i++) dis[i] = size[i] = f[i] = 0; }
void add_edge(int u, int v) { ++cnt; e[cnt].to = v; e[cnt].next = head[u]; head[u] = cnt; }
inline int read() { int x = 0; char ch = getchar(); while (ch > '9' || ch < '0') ch = getchar(); while (ch <= '9' && ch >= '0') x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar(); return x; }
void dfs(int u, int fa) { size[u] = a[u]; int maxp = 0; for (int i = head[u]; i; i = e[i].next) { int v = e[i].to; if (v == fa) continue; dfs(v, u), size[u] += size[v], dis[v] += size[v]; dis[u] += dis[v], maxp = ((dis[maxp] > dis[v]) ? maxp : v); } if(!maxp) return (void) (f[u] = 0); if(dis[u] >= 2 * (dis[maxp])) f[u] = (dis[u] / 2); else f[u] = dis[u] - dis[maxp] + std::min(f[maxp], (2 * dis[maxp] - dis[u]) / 2); }
int main() { ans = inf; n = read(), scanf("%s", s + 1); for (int i = 1; i <= n; i++) a[i] = s[i] - '0'; for (int i = 1; i < n; i++) { int u = read(), v = read(); add_edge(u, v), add_edge(v, u); } for (int i = 1; i <= n; i++) { init(), dfs(i, 0); if(dis[i] & 1) continue; if (f[i] * 2 >= dis[i]) ans = std::min(ans, dis[i] / 2); } printf("%d\n", (ans == inf) ? -1 : ans); return 0; }
|