SRM 502 div1 hard:TheCowDivOne
topcoderっぽい良い問題だと思います。
解法
無理ゲーにしか見えなかったので解説(http://apps.topcoder.com/wiki/display/tc/SRM+502)を参考にしました。こちらには考える動機的なものも多少書いてあります(それでも思いつかねぇ…ってレベルですが)。
関数F(d, k, a)を以下のように定義します。
F(d, k, a) := すべての要素がユニークであるような数列のうち, mod となるようなものは何通りあるか
ここでは数列は「バラバラである」ということだけ意識していることに注意してください。すなわち,[1, 2, 3]と[2, 3, 1]や[1,2]と[2,1]は別々に数えます。
さて,このように定義すると,F(N, K, 1)/K! というのはこの問題の答えを示していることに気づきます(F(N, K, 1)自体は順番を考慮していないのでK!で割る必要があるんですね)。
ということでこのF(d, k, a)を求めることに全力を使います。これを求めるために,以下のように考えてみます(漸化式を立てることを意識しながら考えています)。
まず最初のk-1個については数列の要素がバラバラであると仮定します。すると,最後のk個目のみどのようにすればよいか考えれば済みます。
例えば最初のk-1個の和がSであるとしましょう。すると,
mod という式は
mod という式に変形できます。
さて,このSは何でも良いのでしょうか?例えばa=2,d=4,S=1とすると,は偶数なのにSは奇数なので絶対に4の倍数になりません。よって,この場合はF(d,k,a)の数え上げで含まれないことになります。これをさらに一般化して考えると,「Sとして許される値は,aとdの最大公約数の倍数である」ということに気づきます。まぁこれはよくある整数問題なのでOKですよね。
ではこのSを作る方法は何通りあるのかを考えてみます。これは
mod (gはaとdの最大公約数)の場合の数を求めるのと同じになります(Sはgの倍数なので)。これは関数F(g, k-1, 1)で求めることが出来ますね。
次にこのそれぞれのSに対して,としてあり得るものが何通りあるのかを考えます。最初にとして0からd-1までの値を考えると[X_k]としてあり得る値はg個あります(これは整数問題的なアレですね)。しかし実際にはとしてあり得る値は0からN-1までなのでそのN/d倍になります(dはNの約数になるのでN/dは割り切れる数です)。
これをまとめると場合の数はF(g, k-1, 1)*g*N/dとなります。しかし,これは求めるべきF(d, k, a)とは一致しません。なぜなら,がと同じ数になる可能性があるからです。ということでこの分マイナスしなければなりませんが,これは簡単でがどれかに一致するということは数列[X_1, ..., X_{k-1}]についてmod dとなるということなので,これはF(d, k-1, a+1)で求められます。以上により,
F(d, k, a) = F(g, k-1, 1)*g*N/d - F(d, k-1, a+1)
となります。
この再帰関数をメモ化しながら求めれば良いです。
以下ソースコード。考察はかなり難しいけどコードはスッキリ。
const ll MOD = 1e9+7; map<vi, ll> mm; int n; ll dfs(int d, int k, int a) { if (k == 0) return 1; vi v(3); v[0] = d, v[1] = k, v[2] = a; if (mm.find(v) != mm.end()) return mm[v]; int g = __gcd(d, a); ll p = dfs(g, k-1, 1); (p *= ((ll)n*g) / d) %= MOD; ll q = (dfs(d, k-1, (a+1)%d)*(k-1)) % MOD; return mm[v] = (p-q+MOD) % MOD; } // extgcd ll extgcd(ll a, ll b, ll& x, ll& y) { ll d = a; if (b != 0) { d = extgcd(b, a % b, y, x); y -= (a / b) * x; } else { x = 1; y = 0; } return d; } // mod_inverse ll mod_inverse(ll a, ll m) { ll x, y; extgcd(a, m, x, y); return (m+x%m) % m; } class TheCowDivOne { public: int find(int N, int K) { n = N; mm.clear(); ll ret = dfs(N, K, 1); for (int i = 1; i <= K; i++) { (ret *= mod_inverse(i, MOD)) %= MOD; } return ret; } };