题解归档 - cf2219D

题解归档 - cf2219D

本文由 cf-code 本地题解库自动归档;公开内容以本地 AC/验证版本为准。

思路

cf2219D - MEX Replacement on Tree

Pattern

For every vertex v, mex1[v] is the MEX of the values on the root-to-v path, and
mex2[v] is the smallest value strictly greater than mex1[v] that is also
missing from that path.

An operation at vertex u replaces p[u] with its path MEX mex1[u]; only the path
MEX values of vertices in subtree(u) can change. Let:

  • a = p[u]
  • m = mex1[u]

Since a lies on the root-to-u path, a != m. If a < m, value a becomes missing on
every affected path, so every affected MEX drops to at most a and the operation
cannot improve the sum. Only a > m matters.

For a descendant y:

  • if mex1[y] == m, inserting m at u fills the current gap and deleting a
    opens a new one, so the new MEX is min(a, mex2[y]);
  • if mex1[y] > m, the only possible change is deleting a, so the new MEX is
    min(a, mex1[y]).

Therefore the gain for u splits into:

  • positive part over descendants with mex1[y] == m:
    min(a, mex2[y]) - m;
  • negative part over descendants with mex1[y] > a:
    a - mex1[y].

Descendants with m < mex1[y] <= a keep their MEX and contribute nothing.

Algorithm

Compute tin/tout, mex1, and mex2 by DFS from root while maintaining a set
of values missing on the current root path.

Compute the negative part offline by sorting vertices by mex1 descending and
queries by p[u] descending, using Fenwick trees over Euler positions.

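For the positive part, process each equal-mex1 group separately and answer its
queries in increasing order of a. For the current threshold a, Fenwick trees
over Euler positions hold the count and the sum of mex2 for group vertices with
mex2 <= a. Every other group vertex in the subtree contributes a. An Euler
interval query therefore gives:

sum(min(a, mex2[y])) - m * count.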

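The answer is base_sum + max(0, best_gain), where base_sum is the sum of mex1
over all vertices and best_gain is the largest gain over vertices u with a > m.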

Checks

  • Required precheck done: python tools/math_reasoning_search.py --problem
    cf2219D -n 5.
  • Brute force recomputes the full path-MEX sum after every possible operation
    on random small trees.

代码

来源:problems/cf2219D/solution.cpp

#include <bits/stdc++.h>
using namespace std;

using ll = long long;

// Fenwick (binary indexed) tree over 1-based positions 1..n holding 64-bit sums.
// Point update and prefix/range sum queries each take O(log n).
struct Fenwick {
    int n = 0;
    vector<ll> bit;  // bit[0] is unused; tree nodes occupy indices 1..n

    Fenwick() = default;
    explicit Fenwick(int n_) { init(n_); }

    // Resizes to n_ positions, all reset to zero.
    void init(int n_) {
        n = n_;
        bit.assign(n + 1, 0);
    }

    // Adds val at position idx. Precondition: 1 <= idx (idx == 0 never advances).
    void add(int idx, ll val) {
        while (idx <= n) {
            bit[idx] += val;
            idx += idx & -idx;
        }
    }

    // Returns the sum over positions 1..idx (0 when idx <= 0). Precondition: idx <= n.
    ll sum_prefix(int idx) const {
        ll total = 0;
        while (idx > 0) {
            total += bit[idx];
            idx &= idx - 1;  // strip the lowest set bit (== idx -= idx & -idx)
        }
        return total;
    }

    // Returns the sum over positions l..r, or 0 for an empty range (l > r).
    ll range_sum(int l, int r) const {
        return l > r ? 0 : sum_prefix(r) - sum_prefix(l - 1);
    }
};

// Per-vertex path-MEX data produced by one DFS from vertex 1.
//   tin/tout : Euler-tour interval; y is in subtree(u) iff tin[u] <= tin[y] <= tout[u].
//   mex1[u]  : MEX of the values on the root->u path.
//   mex2[u]  : smallest value greater than mex1[u] also missing on that path
//              (falls back to n if no such value is in the tracked range [0, n]).
struct PathMex {
    vector<int> tin, tout, mex1, mex2;
};

// Iterative DFS (no recursion, so path-shaped trees cannot overflow the stack)
// maintaining `missing` = values in [0, n] absent from the current root path.
// NOTE(review): re-inserting p[u] on exit is only correct if the values p[1..n]
// are pairwise distinct (the write-up treats p as a permutation) — confirm.
static PathMex compute_path_mex(int n, const vector<int> &p, const vector<vector<int>> &g) {
    PathMex info;
    info.tin.assign(n + 1, 0);
    info.tout.assign(n + 1, 0);
    info.mex1.assign(n + 1, 0);
    info.mex2.assign(n + 1, 0);

    set<int> missing;
    for (int x = 0; x <= n; x++) missing.insert(x);

    int timer = 0;
    vector<array<int, 3>> st;  // {vertex, parent, state}; state 0 = enter, 1 = exit
    st.push_back({1, 0, 0});
    while (!st.empty()) {
        auto [u, fa, state] = st.back();
        st.pop_back();
        if (state == 0) {
            info.tin[u] = ++timer;
            missing.erase(p[u]);
            info.mex1[u] = *missing.begin();
            auto it = missing.upper_bound(info.mex1[u]);
            info.mex2[u] = (it == missing.end() ? n : *it);
            st.push_back({u, fa, 1});
            // Push children in reverse so they are entered in adjacency order.
            for (int i = (int)g[u].size() - 1; i >= 0; i--) {
                int v = g[u][i];
                if (v != fa) st.push_back({v, u, 0});
            }
        } else {
            info.tout[u] = timer;
            missing.insert(p[u]);
        }
    }
    return info;
}

// neg[u] = sum over y in subtree(u) with mex1[y] > p[u] of (p[u] - mex1[y]).
// Offline sweep: queries by p descending, vertices activated by mex1 descending;
// Fenwick trees over Euler positions store count and sum of mex1 of active vertices.
static vector<ll> negative_gain(int n, const vector<int> &p, const PathMex &info) {
    vector<int> by_mex(n);
    iota(by_mex.begin(), by_mex.end(), 1);
    sort(by_mex.begin(), by_mex.end(), [&](int a, int b) {
        return info.mex1[a] > info.mex1[b];
    });
    vector<int> by_value(n);
    iota(by_value.begin(), by_value.end(), 1);
    sort(by_value.begin(), by_value.end(), [&](int a, int b) {
        return p[a] > p[b];
    });

    vector<ll> neg(n + 1, 0);
    Fenwick cnt(n), sum(n);
    int ptr = 0;
    for (int u : by_value) {
        const int a = p[u];
        while (ptr < n && info.mex1[by_mex[ptr]] > a) {
            const int v = by_mex[ptr++];
            cnt.add(info.tin[v], 1);
            sum.add(info.tin[v], info.mex1[v]);
        }
        const ll c = cnt.range_sum(info.tin[u], info.tout[u]);
        const ll s = sum.range_sum(info.tin[u], info.tout[u]);
        neg[u] = 1LL * a * c - s;
    }
    return neg;
}

// pos[u] = sum over y in subtree(u) with mex1[y] == mex1[u] of
//          min(p[u], mex2[y]) - mex1[u], computed for every u with p[u] > mex1[u]
//          (entries of other vertices stay 0).
// Each equal-mex1 group is swept on its own: queries in increasing p[u] while
// group vertices with mex2 <= p[u] are activated in Fenwick trees (count and
// sum of mex2). Activations are undone afterwards so the trees are reused.
static vector<ll> positive_gain(int n, const vector<int> &p, const PathMex &info) {
    vector<vector<int>> group(n + 1);  // mex1 values lie in [0, n]
    for (int u = 1; u <= n; u++) group[info.mex1[u]].push_back(u);

    vector<ll> pos(n + 1, 0);
    Fenwick le_cnt(n), le_sum(n);
    for (int m = 0; m <= n; m++) {
        const vector<int> &nodes = group[m];
        if (nodes.empty()) continue;

        // Sorted Euler positions of the group: counts group members in a subtree.
        vector<int> tins;
        tins.reserve(nodes.size());
        for (int u : nodes) tins.push_back(info.tin[u]);
        sort(tins.begin(), tins.end());

        vector<int> points = nodes;
        sort(points.begin(), points.end(), [&](int a, int b) {
            return info.mex2[a] < info.mex2[b];
        });

        vector<int> queries;
        for (int u : nodes) {
            if (p[u] > m) queries.push_back(u);
        }
        sort(queries.begin(), queries.end(), [&](int a, int b) {
            return p[a] < p[b];
        });

        vector<int> added;
        size_t j = 0;
        for (int u : queries) {
            const int a = p[u];
            while (j < points.size() && info.mex2[points[j]] <= a) {
                const int v = points[j++];
                le_cnt.add(info.tin[v], 1);
                le_sum.add(info.tin[v], info.mex2[v]);
                added.push_back(v);
            }

            const ll total = upper_bound(tins.begin(), tins.end(), info.tout[u])
                           - lower_bound(tins.begin(), tins.end(), info.tin[u]);
            const ll small_cnt = le_cnt.range_sum(info.tin[u], info.tout[u]);
            const ll small_sum = le_sum.range_sum(info.tin[u], info.tout[u]);
            // Vertices with mex2 <= a contribute mex2, the rest contribute a;
            // every vertex loses its old MEX m.
            pos[u] = small_sum + 1LL * a * (total - small_cnt) - 1LL * m * total;
        }

        for (int v : added) {
            le_cnt.add(info.tin[v], -1);
            le_sum.add(info.tin[v], -info.mex2[v]);
        }
    }
    return pos;
}

// Reads T test cases (n, values p[1..n], then n-1 edges) and, for each, prints
// the sum of path MEX over all vertices plus the best non-negative gain of a
// single operation at a vertex u with p[u] > mex1[u].
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int T;
    cin >> T;
    while (T--) {
        int n;
        cin >> n;
        vector<int> p(n + 1);
        for (int i = 1; i <= n; i++) cin >> p[i];

        vector<vector<int>> g(n + 1);
        for (int i = 0; i < n - 1; i++) {
            int u, v;
            cin >> u >> v;
            g[u].push_back(v);
            g[v].push_back(u);
        }

        const PathMex info = compute_path_mex(n, p, g);
        const vector<ll> neg = negative_gain(n, p, info);
        const vector<ll> pos = positive_gain(n, p, info);

        ll base = 0;
        ll best = 0;  // operating is only worthwhile if the gain is positive
        for (int u = 1; u <= n; u++) {
            base += info.mex1[u];
            if (p[u] > info.mex1[u]) {
                best = max(best, pos[u] + neg[u]);
            }
        }
        cout << base + best << '\n';
    }
    return 0;
}
~  ~  The   End  ~  ~


 赏 
感谢您的支持,我会继续努力哒!
支付宝收款码
tips
文章二维码 分类标签:归档TypechoAutoUpload
文章标题:题解归档 - cf2219D
文章链接:https://www.fangshaonian.cn/archives/209/
最后编辑:2026 年 6 月 28 日 19:04 By 方少年
许可协议: 署名-非商业性使用-相同方式共享 4.0 国际 (CC BY-NC-SA 4.0)
(*) 3 + 3 =
快来做第一个评论的人吧~