読者です 読者をやめる 読者になる 読者になる

mayoko’s diary

プロコンとかいろいろ。

yukicoder No.317 辺の追加

yukicoder

いろいろ勉強になりました。良い問題。

解法

よくある計算量の短縮で, 「 \frac{1}{1} + \frac{1}{2} + ... + \frac{1}{n} \log(n) に収束する」というのがあります。今回はこれを利用しましょう。

まず UnionFind を使ってグラフ上の連結成分, およびその連結成分の頂点数を求めます。

で, dp[i][j] = (頂点数が i 以下の連結成分を使った時に, 頂点数 j の連結成分を作るために必要な最小の辺の数+1) とします。
この dp は, 頂点数が i の連結成分が cnt[i] 個あったとすると,

dp[i][j] = min(dp[i-1][j-i], dp[i-1][j-2*i], ..., dp[i-1][j-cnt[i]*i])

となります。しかし, cnt[i] は O(N) になる可能性があるので, この計算を素直にやっていると計算量が O(N^2 log N) (log N は最初に書いた  \log(n)に収束するやつです) になってしまいます。これでは間に合わないので, 少し工夫をしましょう。

cnt[i] は  a_0 2^0 + a_1 2^1 + ... + a_n 2^n という形で書くことが出来ます。なので, 上のようにいちいち 0 <= k <= cnt[i] を満たすすべての k を調べなくても, k = 1, 2, 4, 8, ... の形のものだけを調べれば, 一般の場合にも対応できます。

この工夫により, 計算量が O( N \log^2 N) になるので, 無事計算できます。…なんですが, 個人的にはここでいろいろ勘違いしたので, そのミスも載せておきます。下のコードでは, ans[0] = 0; と書いてあるところから, ans をそれぞれ出力する間に書いてある部分に該当します。

まずひとつ目。

    for (int i = 1; i <= N; i++) {
        for (int k = 20; k >= 0; k--) {
            if ((cnt[i]>>k)&1) {
                int num = 1<<k;
                for (int j = N-i*num; j >= 0; j--) {
                    int nj = j+i*num;
                    ans[nj] = min(ans[nj], ans[j]+num);
                }
            }
        }
    }

例えば cnt[i] = 6 になる場合を考えると, k = 0 で if 文の中身が成り立ちませんが, cnt[i] = 6 なので, 当然 頂点数 i の連結成分を 1 つだけ使いたい, という状況が現れる可能性があります。その状況を無視しているので, このコードは正しくありません。

じゃあこうすればええやんけwと思って次に提出したコードがこちら。

    for (int i = 1; i <= N; i++) {
        bool flag = false;
        for (int k = 20; k >= 0; k--) {
            if (((cnt[i]>>k)&1) || flag) {
                flag = true;
                int num = 1<<k;
                for (int j = N-i*num; j >= 0; j--) {
                    int nj = j+i*num;
                    ans[nj] = min(ans[nj], ans[j]+num);
                }
            }
        }
    }

上のコードでは k = 0 が採用されなかったので, k = 1 以上のところで OK になっていれば k = 0 も使っていいよ, というコードに直しています。ただ, これも cnt[i] = 6 のとき NG で, これだと cnt[i] は 6 までしか使ってはいけないのに 7 使ってしまうコードになっています(k = 2, 1, 0 で頂点数 i の連結成分を使うことにすると, 4+2+1 = 7 個の連結成分を使う)。

ということで, これらのコードは NG です。

struct UnionFind {
    vector<int> par;
    int n, cnt;
    UnionFind(const int& x = 0) {init(x);}
    void init(const int& x) {par.assign(cnt=n=x, -1);}
    inline int find(const int& x) {return par[x] < 0 ? x : par[x] = find(par[x]);}
    inline bool same(const int& x, const int& y) {return find(x) == find(y);}
    inline bool unite(int x, int y) {
        if ((x = find(x)) == (y = find(y))) return false;
        --cnt;
        if (par[x] > par[y]) swap(x, y);
        par[x] += par[y];
        par[y] = x;
        return true;
    }
    inline int count() const {return cnt;}
    inline int count(int x) {return -par[find(x)];}
};

const int INF = 1e8;

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    int N, M;
    cin >> N >> M;
    UnionFind uf(N);
    for (int i = 0; i < M; i++) {
        int u, v;
        cin >> u >> v;
        u--; v--;
        uf.unite(u, v);
    }
    vector<int> memo(N);
    for (int i = 0; i < N; i++) memo[uf.find(i)]++;
    vector<int> cnt(N+1);
    for (int i = 0; i < N; i++) if (memo[i]) cnt[memo[i]]++;
    vector<int> ans(N+1, INF);
    ans[0] = 0;
    for (int i = 1; i <= N; i++) {
        for (int k = 1; cnt[i]; k <<= 1) {
            int num = min(cnt[i], k);
            for (int j = N-num*i; j >= 0; j--) {
                int nj = j+num*i;
                ans[nj] = min(ans[nj], ans[j]+num);
            }
            cnt[i] -= num;
        }
    }
    for (int i = 1; i <= N; i++) cout << ((ans[i]==INF) ? -1 : ans[i]-1) << endl;
    return 0;
}