mayoko’s diary

プロコンとかいろいろ。

AtCoder Regular Contest 043 C - 転倒距離

本番中解法全く思い浮かばなかったのは問題だけど, 朝からちょくちょくデバッグして 15 時にようやく AC する程度に実装力がないので解けなかった気がする。

解法

スライド通りです。

www.slideshare.net

ちょくちょくデバッグしましょう。

const int MAXN = 100010;
int A[MAXN], B[MAXN], rA[MAXN], nB[MAXN], rB[MAXN], ans[MAXN], memo[MAXN];

// 1-index
template<typename T>
class BIT {
public:
    vector<T> bit;
    int n;
    BIT(int N) : n(N) {
        bit.resize(N+1);
    }
    T sum(int i) {
        T s = 0;
        while (i > 0) {
            s += bit[i];
            i  -= i&-i;
        }
        return s;
    }
    void add(int i, int x) {
        while (i <= n) {
            bit[i] += x;
            i += i&-i;
        }
    }
};

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    int N;
    cin >> N;
    for (int i = 1; i <= N; i++) cin >> A[i];
    for (int i = 1; i <= N; i++) cin >> B[i];
    // A の対応から B の配列を変更
    for (int i = 1; i <= N; i++) rA[A[i]] = i;
//    for (int i = 1; i <= N; i++) cout << rA[i] << "  ";
//    cout << endl;

    for (int i = 1; i <= N; i++) nB[i] = rA[B[i]];
//    for (int i = 1; i <= N; i++) cout << nB[i] << "  ";
//    cout << endl;

    // B の反転数を求める
    for (int i = 1; i <= N; i++) rB[nB[i]] = i;
    BIT<int> bit(N);
    ll rev = 0;
    for (int i = 1; i <= N; i++) {
        rev += i-1-bit.sum(rB[i]);
        bit.add(rB[i], 1);
    }
    //cout << rev << endl;

    // 答えを求める
    if (rev%2) {
        cout << -1 << endl;
        return 0;
    }
    rev /= 2;
    // 自分より左にあってかつ自分よりでかい数がいくつあるかを求める
    BIT<int> bit2(N);
    for (int i = N; i >= 1; i--) {
        memo[i] = bit2.sum(rB[i]);
        bit2.add(rB[i], 1);
    }
//    for (int i = 1; i <= N; i++) {
//        cout << memo[i];
//        if (i < N) cout << " ";
//    }
//    cout << endl;
    int cur = 1;
    while (1) {
        if (memo[cur] < rev) {
            rev -= memo[cur];
            ans[cur] = cur;
        } else {
            int j = cur;
            for (int i = 1; i <= N; i++) {
                if (nB[i] >= cur) ans[j++] = nB[i];
            }
            int index = 0;
            for (; ans[index] != cur; index++);
            while (rev) {
                if (ans[index] < ans[index-1]) rev--;
                swap(ans[index], ans[index-1]);
                index--;
            }
            break;
        }
        cur++;
    }
//    for (int i = 1; i <= N; i++) {
//        cout << ans[i];
//        if (i < N) cout << " ";
//    }
//    cout << endl;
    for (int i = 1; i <= N; i++) {
        cout << A[ans[i]];
        if (i < N) cout << " ";
    }
    cout << endl;
    return 0;
}