树点编号

因为需要保证任意的删除顺序都保证图联通,所以白色黑色分别只能构成一个联通块。

我们直接在树上 dfs 一遍找出子树大小为 aa 或者 bb 的一棵子树,将其染成相应的颜色,剩下的部分染成另一个颜色即可,因为每次删除编号绝对值最小的,所以按后序遍历分配编号即可。

#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
struct Rhine_Lab {
Rhine_Lab()
{
freopen("tom.in", "r", stdin);
freopen("tom.out", "w", stdout);
}
};
Rhine_Lab Ptilopsis_w;

const int N = 1e5+10;

int n, a, b, rt;
int sub, type;
int cnta, cntb;
int siz[N], num[N];
vector<int> ver[N];

void dfs(int x, int fa);
void paint(int x, int fa, int type);

int main()
{
cin >> n >> a >> b;
for(int i = 1; i < n; i++)
{
int x, y; cin >> x >> y;
ver[x].push_back(y);
ver[y].push_back(x);
}


dfs(1, 0);

if(!rt) { cout << -1; return 0; }
else if(a == 0)
paint(rt, 0, 0);
else if(b == 0)
paint(rt, 0, 1);
else
{
paint(sub, rt, type);
paint(rt, sub, type^1);
}
for(int i = 1; i <= n; i++)
cout << num[i] << "\n";
}
void dfs(int x, int fa)
{
siz[x] = 1;
for(auto y : ver[x])
{
if(y == fa) continue;
dfs(y, x); siz[x] += siz[y];
if(siz[y] == a) rt = x, sub = y, type = 0;
if(siz[y] == b) rt = x, sub = y, type = 1;
}

if(n-siz[x] == a) rt = x, sub = fa, type = 0;
if(n-siz[x] == b) rt = x, sub = fa, type = 1;
}
void paint(int x, int fa, int type)
{
for(auto y : ver[x])
{
if(y == fa) continue;
paint(y, x, type);
}
num[x] = ++(type?cntb:cnta);
if(type) num[x] = -num[x];
}

树根选取

考虑一棵子树的贡献。

如果这一棵子树能覆盖所有的颜色,那么在这棵子树外找一个最长链即可让这棵子树的深度最大。

如果除去这棵子树的部分能覆盖所有颜色,就在子树内找个最长链即可让子树外的部分深度最大。

判断子树外是否能覆盖所有颜色只需要判断这一棵子树是不是将某个颜色的所有点包含了即可。

最长链可以通过简单的树形 DP 求出。

统计颜色可以用树上差分/树上启发式合并/线段树合并等许多方法。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>
using namespace std;
struct Rhine_Lab {
Rhine_Lab()
{
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
}
};
Rhine_Lab Ptilopsis_w;

const int N = 1e6+10;

namespace tr {
struct node {
int ls, rs;
int cnt, sum;
bool all;
} a[N*10];
int root[N], node_cnt, lim;
void add(int &i, int pos, int l = 1, int r = lim);
void merge(int &x, int &y, int l = 1, int r = lim);
int query(int x);
}
namespace tree {
int len1[N], len2[N], falen[N];
void prework1(int x, int fa);
void prework2(int x, int fa);
}
using namespace tr;
using namespace tree;

int n, m, ans;
int tot[N];
int dis[N], col[N];
vector<int> ver[N];

void dfs(int x, int fa);

int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++)
cin >> col[i], tot[col[i]]++;
for(int i = 1; i < n; i++)
{
int a, b;
cin >> a >> b;
ver[a].push_back(b);
ver[b].push_back(a);
}

tree::prework1(1, 0);
tree::prework2(1, 0);

tr::lim = m;
dfs(1, 0);

cout << ans << "\n";
}

void dfs(int x, int fa)
{
tr::add(root[x], col[x]);
for(auto y : ver[x])
{
if(y == fa) continue;
dfs(y, x);
merge(root[x], root[y]);
}
if(a[root[x]].sum == m)
ans = max(ans, falen[x]+1);
if(!a[root[x]].all)
ans = max(ans, len1[x]+2);
}

namespace tree {
void prework1(int x, int fa)
{
len1[x] = -1e9;
len2[x] = -1e9;
for(auto y : ver[x])
{
if(y == fa) continue;
prework1(y, x);
if(len1[y]+1 > len1[x])
{
len2[x] = len1[x];
len1[x] = len1[y]+1;
}
else if(len1[y]+1 > len2[x])
len2[x] = len1[y]+1;
}
len1[x] = max(len1[x], 0);
}
void prework2(int x, int fa)
{
for(auto y : ver[x])
{
if(y == fa) continue;
if(len1[y]+1 == len1[x])
falen[y] = max(falen[x], len2[x]) + 1;
else
falen[y] = max(falen[x], len1[x]) + 1;
prework2(y, x);
}
}
}

namespace tr {
#define ls(i) a[i].ls
#define rs(i) a[i].rs
#define lmid ((l+r)>>1)
#define rmid ((l+r+2)>>1)
vector<int> bin;
int newnode()
{
int x;
if(bin.size()) x = bin.back(), bin.pop_back();
else x = ++node_cnt;
a[x].all = a[x].cnt = a[x].ls = a[x].rs = a[x].sum = 0;
return x;
}
void pushup(int i)
{
a[i].sum = a[ls(i)].sum + a[rs(i)].sum;
a[i].all = a[ls(i)].all | a[rs(i)].all;
}
void add(int &i, int pos, int l, int r)
{
if(!i) i = newnode();
if(l == r)
{
a[i].cnt++;
a[i].sum |= 1;
a[i].all = (a[i].cnt==tot[l]);
return void();
}
if(pos <= lmid) add(ls(i), pos, l, lmid);
if(pos >= rmid) add(rs(i), pos, rmid, r);
pushup(i);
}
void merge(int &x, int &y, int l, int r)
{
if(!x or !y) return void(x = x|y);
if(l == r)
{
a[x].cnt += a[y].cnt;
a[x].sum |= a[y].sum;
a[x].all = (a[x].cnt == tot[l]);
return void();
}
merge(ls(x), ls(y), l, lmid);
merge(rs(x), rs(y), rmid, r);
bin.push_back(y);
pushup(x);
}
}