题解归档 - cf2219D
本文最后由方少年更新于2026 年 6 月 28 日,已超过0天没有更新。如果文章内容或图片资源失效,请留言反馈,将会及时处理,谢谢!
题解归档 - cf2219D
本文由 cf-code 本地题解库自动归档;公开内容以本地 AC/验证版本为准。
- 本地编号:cf2219D
- 本地来源:problems/cf2219D/idea.md
- 题目链接:https://codeforces.com/contest/2219/problem/D
- 原始标题:cf2219D - MEX Replacement on Tree
思路
cf2219D - MEX Replacement on Tree
Pattern
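For every vertex v, mex1[v] is the MEX of the values on the root-to-v path, and
mex2[v] is the smallest value missing from that path that is strictly greater
than mex1[v].
If we operate at vertex u (its value a is deleted and the path MEX m is
inserted at u), only vertices in subtree(u) change. Let:
a = p[u], m = mex1[u].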
If a < m, value a becomes missing on every affected path, so the operation
cannot improve the sum. (a == m is impossible: a lies on the path, while m is
missing from it.) Only a > m matters.
For a descendant y:
- if mex1[y] == m, after inserting m at u, the new MEX is min(a, mex2[y]);
- if mex1[y] > m, the only possible change is deleting a, so the new MEX is min(a, mex1[y]).
Therefore the gain for u splits into:
- positive part over descendants with mex1[y] == m: min(a, mex2[y]) - m;
- negative part over descendants with mex1[y] > a: a - mex1[y]
  (descendants with m < mex1[y] < a keep their MEX and contribute 0).
Algorithm
Compute tin/tout, mex1, and mex2 by DFS from root while maintaining a set
of values missing on the current root path.
Compute the negative part offline by sorting vertices by mex1 descending and
queries by p[u] descending, using Fenwick trees over Euler positions.
For the positive part, process each equal-mex1 group separately. For a query
threshold a, maintain group vertices with mex2 <= a in Fenwick trees, so an
Euler interval query gives:
sum(min(a, mex2[y])) - m * count.
Answer is base_sum + max(0, best_gain).
Checks
- `python tools/math_reasoning_search.py --problem cf2219D -n 5`: required precheck done.
- Brute force recomputes the full path-MEX sum after every possible operation
  on random small trees.
代码
来源:problems/cf2219D/solution.cpp
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Fenwick (binary indexed) tree over 1-based positions 1..n holding ll sums.
// Supports point add and prefix / range sum queries in O(log n).
struct Fenwick {
    int n = 0;
    vector<ll> bit;  // bit[0] is unused; bit[1..n] hold the partial sums.

    Fenwick() = default;
    explicit Fenwick(int n_) { init(n_); }

    // Resets the tree to n_ positions, all zero.
    void init(int n_) {
        n = n_;
        bit.assign(n + 1, 0);
    }

    // Adds val at position idx. Positions outside [1, n] are ignored.
    // The explicit idx <= 0 guard matters: at idx == 0 the step idx & -idx
    // is 0, so without it the loop would never terminate.
    void add(int idx, ll val) {
        if (idx <= 0) return;
        for (; idx <= n; idx += idx & -idx) bit[idx] += val;
    }

    // Returns the sum of positions 1..idx. idx <= 0 yields 0; idx > n is
    // clamped to n instead of reading past the end of `bit`.
    ll sum_prefix(int idx) const {
        idx = min(idx, n);
        ll res = 0;
        for (; idx > 0; idx -= idx & -idx) res += bit[idx];
        return res;
    }

    // Returns the sum of positions l..r (inclusive); 0 for an empty range.
    // Bounds outside [1, n] are clipped via sum_prefix.
    ll range_sum(int l, int r) const {
        if (l > r) return 0;
        return sum_prefix(r) - sum_prefix(l - 1);
    }
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;
cin >> T;
while (T--) {
int n;
cin >> n;
vector<int> p(n + 1);
for (int i = 1; i <= n; i++) cin >> p[i];
vector<vector<int>> g(n + 1);
for (int i = 0; i < n - 1; i++) {
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
vector<int> tin(n + 1), tout(n + 1), parent(n + 1), mex1(n + 1), mex2(n + 1);
set<int> missing;
for (int x = 0; x <= n; x++) missing.insert(x);
int timer = 0;
vector<array<int, 3>> st;
st.push_back({1, 0, 0});
while (!st.empty()) {
auto [u, fa, state] = st.back();
st.pop_back();
if (state == 0) {
parent[u] = fa;
tin[u] = ++timer;
missing.erase(p[u]);
mex1[u] = *missing.begin();
auto it = missing.upper_bound(mex1[u]);
mex2[u] = (it == missing.end() ? n : *it);
st.push_back({u, fa, 1});
for (int i = (int)g[u].size() - 1; i >= 0; i--) {
int v = g[u][i];
if (v != fa) st.push_back({v, u, 0});
}
} else {
tout[u] = timer;
missing.insert(p[u]);
}
}
ll base = 0;
vector<vector<int>> group(n + 1), at_value(n);
for (int u = 1; u <= n; u++) {
base += mex1[u];
group[mex1[u]].push_back(u);
at_value[p[u]].push_back(u);
}
vector<ll> neg(n + 1, 0), pos(n + 1, 0);
Fenwick cnt(n), sum(n);
vector<int> by_mex(n);
iota(by_mex.begin(), by_mex.end(), 1);
sort(by_mex.begin(), by_mex.end(), [&](int a, int b) {
return mex1[a] > mex1[b];
});
vector<int> by_value(n);
iota(by_value.begin(), by_value.end(), 1);
sort(by_value.begin(), by_value.end(), [&](int a, int b) {
return p[a] > p[b];
});
int ptr = 0;
for (int u : by_value) {
int a = p[u];
while (ptr < n && mex1[by_mex[ptr]] > a) {
int v = by_mex[ptr++];
cnt.add(tin[v], 1);
sum.add(tin[v], mex1[v]);
}
ll c = cnt.range_sum(tin[u], tout[u]);
ll s = sum.range_sum(tin[u], tout[u]);
neg[u] = 1LL * a * c - s;
}
Fenwick le_cnt(n), le_sum(n);
for (int m = 0; m <= n; m++) {
auto &nodes = group[m];
if (nodes.empty()) continue;
vector<int> tins;
tins.reserve(nodes.size());
for (int u : nodes) tins.push_back(tin[u]);
sort(tins.begin(), tins.end());
vector<int> points = nodes;
sort(points.begin(), points.end(), [&](int a, int b) {
return mex2[a] < mex2[b];
});
vector<int> queries;
for (int u : nodes) {
if (p[u] > m) queries.push_back(u);
}
sort(queries.begin(), queries.end(), [&](int a, int b) {
return p[a] < p[b];
});
vector<int> added;
int j = 0;
for (int u : queries) {
int a = p[u];
while (j < (int)points.size() && mex2[points[j]] <= a) {
int v = points[j++];
le_cnt.add(tin[v], 1);
le_sum.add(tin[v], mex2[v]);
added.push_back(v);
}
ll total = upper_bound(tins.begin(), tins.end(), tout[u])
- lower_bound(tins.begin(), tins.end(), tin[u]);
ll small_cnt = le_cnt.range_sum(tin[u], tout[u]);
ll small_sum = le_sum.range_sum(tin[u], tout[u]);
pos[u] = small_sum + 1LL * a * (total - small_cnt) - 1LL * m * total;
}
for (int v : added) {
le_cnt.add(tin[v], -1);
le_sum.add(tin[v], -mex2[v]);
}
}
ll best = 0;
for (int u = 1; u <= n; u++) {
if (p[u] > mex1[u]) {
best = max(best, pos[u] + neg[u]);
}
}
cout << base + best << '\n';
}
return 0;
}
文章标题:题解归档 - cf2219D
文章链接:https://www.fangshaonian.cn/archives/209/
最后编辑:2026 年 6 月 28 日 19:04 By 方少年
许可协议: 署名-非商业性使用-相同方式共享 4.0 国际 (CC BY-NC-SA 4.0)