SRM 673 div1 med:BearPermutations
解法
さっきと同じように dp を考えます。
dp[n][s] = (長さ n の数列で, スコアの合計が s であるようなものの数) という dp が自然な気がしますが, 今回の dp では最小要素がどこにあるのか, という情報も必要になります。ということで, dp[n][s][c] = (長さ n の数列で, 最小要素は座標 c にあり, スコアの合計が s であるようなものの数) という dp を考えましょう。
まず普通に考えてみます。
c を中心に左側と右側の区間に分かれますが, 左側の最小要素の座標を x, 右側の最小要素の座標を y とすると, スコアとしては y-x だけ加算されます。ということで, これを素直に数えあげると,
dp[c][a][x] * dp[n-c-1][b][y] (ただし, a, b, x, y は a+b+(y-x) = s を満たす)
の和に, 左側の要素の選び方 をかけたものが答えになります。ただ, これは x, y, a は決定しないと求められない dp なので, 一つの状態を計算するために O(N^2*S) 程度かかります。状態量が O(N^2*S) であることを考えると, これはお通夜です。ということで, 計算量を工夫しなければならないです。
もうひとつの dp を考えます(memo とする)。
目的としては,
dp[n][s][c] を memo[c][a] * memo[n-c-1][b] (a+b = s)
というものの和で書けるようにすることです。
これを考えると, memo[n][s] というのは, (長さ n の数列で, スコアの合計が s であるものの数)ということになりそうです。
少し違うのは, この s の中には最小要素の座標 c との差によるスコアも含まれているということです。
例えば memo[c][a] を考える場合は, その左側区間の最小要素が座標 x にあったとすると, a は, c-x と (x を最小要素の座標とした長さ c の数列のスコア)の合計になります。
注目するのは, memo[c][a] のうち c というのが, もとの dp の要素 c と一致していることです。
つまり, もともと数列の最小要素の座標としか考えていなかったものが, 左側の区間では数列の長さという意味も持っているということです。これを考えると, memo[c][a] というのは, dp[c][s-(c-x)][x] の和になります。
うーん, めっちゃわかりにくい気がしますが許して下さい…
const int MAXN = 101; ll C[MAXN][MAXN]; ll fact[MAXN]; ll dp[MAXN][MAXN][MAXN]; ll memo[MAXN][MAXN]; ll dfs(int, int, int, int); ll calc(int, int, int); ll dfs(int n, int s, int c, int MOD) { ll& ret = dp[n][s][c]; if (ret >= 0) return ret; if (n <= 2) { if (s == 0) return ret = 1; else return ret = 0; } ret = 0; if (c == 0 || c == n-1) { for (int i = 0; i < n-1; i++) ret += dfs(n-1, s, i, MOD); ret %= MOD; return ret; } for (int a = 0; a <= s; a++) { int b = s-a; ret += (calc(c, a, MOD)*calc(n-c-1, b, MOD)) % MOD; } ret %= MOD; (ret *= C[n-1][c]) %= MOD; return ret; } ll calc(int n, int s, int MOD) { ll& ret = memo[n][s]; if (ret >= 0) return ret; ret = 0; for (int x = 0; x < n; x++) { if (s-(n-x) >= 0) ret += dfs(n, s-(n-x), x, MOD); } ret %= MOD; return ret; } class BearPermutations { public: int countPermutations(int N, int S, int MOD) { for (int i = 0; i < MAXN; i++) { C[i][0] = 1; for (int j = 1; j <= i; j++) { C[i][j] = C[i-1][j] + C[i-1][j-1]; C[i][j] %= MOD; } } fact[0] = 1; for (int i = 1; i < MAXN; i++) fact[i] = (fact[i-1]*i)%MOD; memset(dp, -1, sizeof(dp)); memset(memo, -1, sizeof(memo)); ll ans = 0; for (int i = 0; i < N; i++) { for (int j = 0; j <= S; j++) { ans += dfs(N, j, i, MOD); } } ans %= MOD; return ans; } };