mayoko’s diary

プロコンとかいろいろ。

SRM 656 div2 hard:PermutationCountsDiv2

SRM 656 に参加した時 med で同じ解法を使ったんですが, いい復習になりました。

解法

包除原理を使います。

例えば, N 個のうち, 2 つだけ p[k] < p[k+1] となるポイントがあったとしましょう。
「区切りごとに p[k] < p[k+1] になっているかはとりあえずおいてといて, とりあえず数列が区切りごとに減少しているようにする」というルールを作ると, 以下のような 4 通りが考えられます。

① 1 つ目の区切りまで単調減少, 1 つ目の区切りで減少するか増加するかは問わない, 2 つ目の区切りまで単調減少, 2 つ目の区切りで減少するか増加するかは問わない, そこから最後まで単調減少
② 2 つ目の区切りまで一気に単調減少, 2 つ目の区切りで減少するか増加するかは問わない, そこから最後まで単調減少
③ 1 つ目の区切りまで単調減少, 1 つ目の区切りで減少するか増加するかは問わない, そこから最後まで一気に単調減少
④ 最初から最後まで単調減少

よくわからない感じですが例えば 1 つ目の例は以下のような感じです。
f:id:mayokoex:20150927125548j:plain

このようにすると, ①の場合で, 例えば 1 つ目の区切りで増加すべきなのに減少してしまった場合を②がカバーしてくれて, 2 つ目の区切りで増加すべきなのに減少してしまった場合を③がカバーしてくれて, ②と③で減らしすぎた数を④がカバーしてくれます。

よって, 答えは ①ー②ー③+④ となります。これは包除原理ですね。

で, ここで大事なのはこの包除原理は数列を区切った数の偶奇しか関係ないという点です。例えば, ①の場合は数列を区切った数が 3 つなので+, ②の場合は数列を区切った数が 2 つなのでー, という感じです。

なので, dp[n][flag] = (n 個目の区切りを見ていて, 区切った数の偶奇が flag で表される時の場合の数) とすれば, うまく dp を解くことが出来ます。

const ll MOD = 1e9+7;

// extgcd
ll extgcd(ll a, ll b, ll& x, ll& y) {
    ll d = a;
    if (b != 0) {
        d = extgcd(b, a % b, y, x);
        y -= (a / b) * x;
    } else {
        x = 1; y = 0;
    }
    return d;
}
// mod_inverse
ll mod_inverse(ll a, ll m) {
    ll x, y;
    extgcd(a, m, x, y);
    return (m+x%m) % m;
}

const int MAXN = 222;
ll fact[MAXN], rfact[MAXN];
ll dp[MAXN][2];
int len, N;
vector<int> pos;

ll nCr(int n, int r) {
    return ((fact[n]*rfact[n-r]) % MOD) * rfact[r] % MOD;
}

ll dfs(int n, int flag) {
    if (n == len-1) return flag==(len-1)%2 ? 1 : -1;
    ll& ret = dp[n][flag];
    if (ret >= 0) return ret;
    ret = 0;
    for (int i = n+1; i < len; i++) {
        ret += nCr(N-pos[n], pos[i]-pos[n]) * dfs(i, flag^1) % MOD;
    }
    ret %= MOD;
    (ret += MOD) %= MOD;
    return ret;
}

class PermutationCountsDiv2 {
public:
    int countPermutations(int N, vector <int> pos) {
        // pos の設定
        ::N = N;
        sort(pos.begin(), pos.end());
        ::pos.clear();
        ::pos.push_back(0);
        for (int el : pos) ::pos.push_back(el);
        ::pos.push_back(N);
        len = ::pos.size();
        fact[0] = rfact[0] = 1;
        for (int i = 1; i < MAXN; i++) {
            fact[i] = (fact[i-1] * i)%MOD;
            rfact[i] = mod_inverse(fact[i], MOD);
        }
        memset(dp, -1, sizeof(dp));
        return (int)dfs(0, 0);
    }