mayoko’s diary

プロコンとかいろいろ。

Educational Codeforces Round 3 E. Minimum spanning tree for each edge

解法

まず普通に MST を作ります。で, その木に辺 (u, v) を追加したと考えると, 閉路が出来ます。閉路のいずれかの辺を取り除くと再び木になるので, 閉路の中からコストが最大の辺を取り除けば良いです。

最大の辺はどうやって取り除けば良いかというと, ダブリングのテクニックを使います。maxCost[k][v] = (頂点 v から v より 2^k 個だけ上にある頂点までの辺の最大コスト) というのを保存しておけば, lca を求めるのと同様の手法で頂点の組 (u, v) における, u -> lca -> v の間の辺の最大コストが O(log N) で求められます。

下のコードは蟻本の p293 に書いてある lca のコードをそのまま使った後 maxCost を付け加えた感じで書きました。

const int MAXN = 200200;

struct UnionFind {
    vector<int> par;
    int n, cnt;
    UnionFind(const int& x = 0) {init(x);}
    void init(const int& x) {par.assign(cnt=n=x, -1);}
    inline int find(const int& x) {return par[x] < 0 ? x : par[x] = find(par[x]);}
    inline bool same(const int& x, const int& y) {return find(x) == find(y);}
    inline bool unite(int x, int y) {
        if ((x = find(x)) == (y = find(y))) return false;
        --cnt;
        if (par[x] > par[y]) swap(x, y);
        par[x] += par[y];
        par[y] = x;
        return true;
    }
    inline int count() const {return cnt;}
    inline int count(int x) {return -par[find(x)];}
};

struct edge {
    int cost;
    int u, v;
    int id;
    edge() {}
    edge(int cost, int u, int v, int id) : cost(cost), u(u), v(v), id(id) {}
    bool operator<(const edge& rhs) const {return cost < rhs.cost;}
};
vector<edge> G;

vector<edge> T[MAXN];
int parent[20][MAXN], depth[MAXN], maxCost[20][MAXN];

void dfs(int v, int p, int d) {
    parent[0][v] = p;
    depth[v] = d;
    for (edge e : T[v]) {
        if (e.v != p) {
            maxCost[0][e.v] = e.cost;
            dfs(e.v, v, d+1);
        }
    }
}

void init(int V) {
    dfs(0, -1, 0);
    for (int k = 0; k < 20-1; k++) {
        for (int v = 0; v < V; v++) {
            if (parent[k][v] < 0) {
                parent[k+1][v] = -1;
                maxCost[k+1][v] = -1;
            } else {
                parent[k+1][v] = parent[k][parent[k][v]];
                maxCost[k+1][v] = max(maxCost[k][v], maxCost[k][parent[k][v]]);
            }
        }
    }
}

pii lca(int u, int v) {
    if (depth[u] > depth[v]) swap(u, v);
    pii ret;
    for (int k = 0; k < 20; k++) {
        if ((depth[v]-depth[u])>>k&1) {
            ret.second = max(ret.second, maxCost[k][v]);
            v = parent[k][v];
        }
    }
    if (u==v) {
        ret.first = u;
        return ret;
    }
    for (int k = 19; k >= 0; k--) {
        if (parent[k][u] != parent[k][v]) {
            ret.second = max(ret.second, maxCost[k][u]);
            ret.second = max(ret.second, maxCost[k][v]);
            u = parent[k][u];
            v = parent[k][v];
        }
    }
    ret.second = max(ret.second, max(maxCost[0][u], maxCost[0][v]));
    ret.first = parent[0][u];
    return ret;
}

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    int n, m;
    cin >> n >> m;
    for (int i = 0; i < m; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        u--; v--;
        G.emplace_back(w, u, v, i);
    }
    sort(G.begin(), G.end());
    UnionFind uf(n);
    ll sum = 0;
    for (edge e : G) {
        if (uf.same(e.u, e.v)) continue;
        sum += e.cost;
        uf.unite(e.u, e.v);
        T[e.u].emplace_back(e.cost, e.u, e.v, e.id);
        T[e.v].emplace_back(e.cost, e.v, e.u, e.id);
    }
    init(n);
    vector<ll> ans(m);
    for (edge e : G) {
        ans[e.id] = sum + e.cost - lca(e.u, e.v).second;
    }
    for (int i = 0; i < m; i++) cout << ans[i] << endl;
    return 0;
}