mayoko’s diary

プロコンとかいろいろ。

SRM 666 div1 med:SumOverPermutations

うーん, 気づけば簡単だなぁ…

解法

基本的に, dp[n] = (問題文中の Z) というようにして, これを上手いこと部分問題に落としこむ, というように考えます。

例えば, n 個のうち下の図のような真ん中らへんの頂点を最初に選んだとします。
f:id:mayokoex:20150830175552j:plain
この頂点では, その周辺が何も決まっていないので, その頂点のボールは N 通り考えられます。その後, その左右2つの Z がいくらになっているのかを計算します。(少し説明がいい加減なので後でもっときちんと言いますが)すると, 「この真ん中らへんを最初に選んだ」という条件のもとでの f(P) の合計値は,

N * dp[l] * dp[r] *  {}_{n-1}C_{l}

で求めることが出来ます。
まず {}_{n-1}C_{l}ですが, これは左と右をいじる順番を好きに選んで良いので, その分の掛け算です。左側にあるビンでボールを決定する動作を L, 右側にあるビンでボールを決定する動作を R とすると, 例えば
LLLRRRR
という動作がありえますが, この L と R は好きな順番で操作して良いわけです。この L と R の並べ替え方は  {}_{n-1}C_{l} になりますね。

また, 単純に dp[l] * dp[r] なんてやって良いのか, と思われるかもしれませんが, これで OK です。
例えば左側であり得る f(P) それぞれの値が (a, b, c, ..., z) で, 右側でありえる f(P) それぞれの値が (A, B, C, ..., Z) であるとすると, 左側 * 右側 の掛け算をした際にありえる数字というのは,

a*A, a*B, ..., a*Z,
b*A, b*B, ..., b*Z,
...
z*A, z*B, ..., z*Z

となりますが, これはまさしく (a+b+...+z) * (A+B+...+Z) の値と等しいです。よって dp[l] * dp[r] で大丈夫なことがわかります。

はい, まだ問題はあって, 単純に dp[l] とか考えてると, N-1 通りとすべきところを N 通りと判断してしまったりするので, そこに注意が必要です。例えば上の図で, 左側の区間では, 右端のボールの選び方は N-1 通りとなるはずです。また, 右側の区間では, 左端のボールの選び方は N-1 通りとなるはずです。ここらへんを上手く考慮するために, 「すでに左端のボールは埋められているか」「すでに右端のボールは埋められているか」のフラグをつけておきます。このようにすると, 結局状態数 N*4 のそれぞれを計算量 N で埋めることができるので, 時間内に答えを求められます。

得た知見
  • 部分問題に分けるという考えを持っておく
  • 区間の中で大事な情報だけを取っておく(今回は [l, r] のように区間の座標値は必要なくて, 端っこがどうなっているかだけが重要だった)
const int MAXN = 4002;
const ll MOD = 1e9+7;
ll nCr[MAXN][MAXN];
ll dp[MAXN][4];
int N;

// flag: 1 だったら左が埋まってる
// 2 だったら 右が埋まってる
ll dfs(int n, int flag) {
    ll& ret = dp[n][flag];
    if (ret >= 0) return ret;
    ret = 0;
    for (int i = 0; i < n; i++) {
        if (i == 0) {
            if (flag&1) {
                ret += (N-1) * dfs(n-1, flag);
            } else {
                ret += N * dfs(n-1, (flag|1));
            }
            ret %= MOD;
        } else if (i == n-1) {
            if (flag/2) {
                ret += (N-1) * dfs(n-1, flag);
            } else {
                ret += N * dfs(n-1, (flag|2));
            }
            ret %= MOD;
        } else {
            ll tmp = (N*nCr[n-1][i]) % MOD;
            (tmp *= dfs(i, (flag|2))) %= MOD;
            (tmp *= dfs(n-1-i, (flag|1))) %= MOD;
            (ret += tmp) %= MOD;
        }
    }
    return ret;
}

class SumOverPermutations {
public:
    int findSum(int n) {
        N = n;
        for (int i = 0; i <= n; i++) {
            nCr[i][0] = 1;
            for (int j = 1; j <= i; j++) {
                nCr[i][j] = (nCr[i-1][j-1] + nCr[i-1][j]) % MOD;
            }
        }
        memset(dp, -1, sizeof(dp));
        dp[1][0] = N;
        dp[1][1] = dp[1][2] = N-1;
        dp[1][3] = N-2;
        return (int)(dfs(n, 0));
    }
};