mayoko’s diary

プロコンとかいろいろ。

Codeforces Round #301 (Div. 2) E. Infinite Inversions

こういうのあんまり慣れていないので良い練習問題でした。

解法

これによく似た問題として蟻本のBITの説明のところに書いてある問題(p162)があります。これはわかっているという前提で説明します。

2つの数の組(i, j)について,この組が反転する場合として考えられるのは以下の2つの場合です。

①i, jがともに少なくとも1回swapされた数である場合
②i, jのどちらか一方のみが少なくとも1回swapされた数である場合

これらの反点数をそれぞれ別々に数えます。

まず①の方です。これは先ほど言った蟻本の解説の通りやることが出来ます…と言っても僕も結構混乱したのでどうやって数えるのか説明したいと思います。

まず,生データを使っていると配列の数が足りないので座標圧縮する必要があります。
これはどうするかというと,map < int, int > を使って解決しています。M[a]というのは,「今インデックスa上にはなんの数が入っているか」が表されています。例えば,最初はM[a] = aで,aとbの数をスワップするとM[a] = bになりますね。

mapにこういうのをどんどん入れてスワップしていくと結局のところMには「一回でもswapしたことのあるやつのそれぞれのインデックスに今なんの数が入っているか」を表すことになります。

例えば
2
2 6
4 8
という入力が入ったら
M[2] = 6
M[4] = 8
M[6] = 2
M[8] = 4
となりますね。

で,本題に戻りますが蟻本と同じように反点数を求めたいです。蟻本の方で結局のところ反点数を求めるときに大事なのは,「並べられている数列のある数が結局のところ何番目に小さいのか」ということだけです。先ほどの例で言うと
・2は1番目に小さい
・4は2番目に小さい
・6は3番目に小さい
・8は4番目に小さい
という,これが大事です。このように整理してやると,蟻本の解法と同じように解くことが出来ます。今回の問題の場合は,map < int, int > Aという配列を用意して,Aに「Mに出てくる数がそれぞれ何番目に小さいか」ということを記録してもらいます。そうすると,「1回でもswapされる数」はたかだか2n個なのでBITで反点数を求めることが出来ます。
一応Aをさっきの例に照らし合わせるてみると,
A[2] = 1
A[4] = 2
A[6] = 3
A[8] = 4
です。

次に②の方です。こっちは結構簡単です。swapされた数以外で反転している数の個数を求めれば良いだけです。swapされた数でインデックスがa番目の数がbであるとすると,普通に考えればabs(b-a)個の反点数があります。しかしswapされた数は無視しないといけないので,abs(A[b]-A[a])を引き算しなければなりません。

例えばさっきの入力の例で言うと,4番目の数は8なので普通に考えると8-4=4個の反点数がありますが,6,8番目の数はswapされた数のうちに入っているのでこれは無視しなければなりません。よって,数えるべき数は4-2で2になります。

以下ソースコード

int n, m;
map<int, int> M, A;
const int MAXN = 100010;
ll bit[2*MAXN];

ll sum(int k) {
    ll ret = 0;
    while (k) {
        ret += bit[k];
        k -= k&-k;
    }
    return ret;
}

void update(int k, int x) {
    while (k <= m) {
        bit[k] += x;
        k += k&-k;
    }
}

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    cin >> n;
    for (int i = 0; i < n; i++) {
        int a, b;
        cin >> a >> b;
        if (M.count(a) == 0) M[a] = a;
        if (M.count(b) == 0) M[b] = b;
        swap(M[a], M[b]);
    }
    for (auto el : M) {
        A[el.first] = ++m;
    }
    ll ans = 0;
    for (auto el : M) {
        if (el.second > el.first) {
            ans += el.second - el.first - (A[el.second]-A[el.first]);
        } else {
            ans += el.first - el.second - (A[el.first]-A[el.second]);
        }
        ans += sum(m) - sum(A[el.second]);
        update(A[el.second], 1);
    }
    cout << ans << endl;
    return 0;
}