mayoko’s diary

プロコンとかいろいろ。

SRM 502 div1 hard:TheCowDivOne

topcoderっぽい良い問題だと思います。

解法

無理ゲーにしか見えなかったので解説(http://apps.topcoder.com/wiki/display/tc/SRM+502)を参考にしました。こちらには考える動機的なものも多少書いてあります(それでも思いつかねぇ…ってレベルですが)。

関数F(d, k, a)を以下のように定義します。
F(d, k, a) := すべての要素がユニークであるような数列X_1, X_2, ..., X_kのうち,X_1 + X_2 + ... + X_{k-1} + a X_k = 0 mod dとなるようなものは何通りあるか

ここでは数列X_1, X_2, ..., X_kは「バラバラである」ということだけ意識していることに注意してください。すなわち,[1, 2, 3]と[2, 3, 1]や[1,2]と[2,1]は別々に数えます。

さて,このように定義すると,F(N, K, 1)/K! というのはこの問題の答えを示していることに気づきます(F(N, K, 1)自体は順番を考慮していないのでK!で割る必要があるんですね)。

ということでこのF(d, k, a)を求めることに全力を使います。これを求めるために,以下のように考えてみます(漸化式を立てることを意識しながら考えています)。

まず最初のk-1個については数列の要素がバラバラであると仮定します。すると,最後のk個目のみどのようにすればよいか考えれば済みます。
例えば最初のk-1個の和がSであるとしましょう。すると,
X_1 + X_2 + ... + X_{k-1} + a X_k = 0 mod dという式は
S + a X_k = 0 mod dという式に変形できます。

さて,このSは何でも良いのでしょうか?例えばa=2,d=4,S=1とすると,a X_kは偶数なのにSは奇数なので絶対に4の倍数になりません。よって,この場合はF(d,k,a)の数え上げで含まれないことになります。これをさらに一般化して考えると,「Sとして許される値は,aとdの最大公約数の倍数である」ということに気づきます。まぁこれはよくある整数問題なのでOKですよね。

ではこのSを作る方法は何通りあるのかを考えてみます。これは

X_1 + X_2 + ... + X_{k-1} = 0 mod g(gはaとdの最大公約数)の場合の数を求めるのと同じになります(Sはgの倍数なので)。これは関数F(g, k-1, 1)で求めることが出来ますね。

次にこのそれぞれのSに対して,X_kとしてあり得るものが何通りあるのかを考えます。最初にX_kとして0からd-1までの値を考えると[X_k]としてあり得る値はg個あります(これは整数問題的なアレですね)。しかし実際にはX_kとしてあり得る値は0からN-1までなのでそのN/d倍になります(dはNの約数になるのでN/dは割り切れる数です)。

これをまとめると場合の数はF(g, k-1, 1)*g*N/dとなります。しかし,これは求めるべきF(d, k, a)とは一致しません。なぜなら,X_kX_1, ..., X_{k-1}と同じ数になる可能性があるからです。ということでこの分マイナスしなければなりませんが,これは簡単でX_kがどれかに一致するということは数列[X_1, ..., X_{k-1}]についてX_1+X_2+...+X_{k-2}+(a+1)X_{k-1}=0mod dとなるということなので,これはF(d, k-1, a+1)で求められます。以上により,

F(d, k, a) = F(g, k-1, 1)*g*N/d - F(d, k-1, a+1)

となります。
この再帰関数をメモ化しながら求めれば良いです。

以下ソースコード。考察はかなり難しいけどコードはスッキリ。

const ll MOD = 1e9+7;
map<vi, ll> mm;
int n;

ll dfs(int d, int k, int a) {
    if (k == 0) return 1;
    vi v(3);
    v[0] = d, v[1] = k, v[2] = a;
    if (mm.find(v) != mm.end()) return mm[v];
    int g = __gcd(d, a);
    ll p = dfs(g, k-1, 1);
    (p *= ((ll)n*g) / d) %= MOD;
    ll q = (dfs(d, k-1, (a+1)%d)*(k-1)) % MOD;
    return mm[v] = (p-q+MOD) % MOD;
}

// extgcd
ll extgcd(ll a, ll b, ll& x, ll& y) {
    ll d = a;
    if (b != 0) {
        d = extgcd(b, a % b, y, x);
        y -= (a / b) * x;
    } else {
        x = 1; y = 0;
    }
    return d;
}

// mod_inverse
ll mod_inverse(ll a, ll m) {
    ll x, y;
    extgcd(a, m, x, y);
    return (m+x%m) % m;
}

class TheCowDivOne {
public:
    int find(int N, int K) {
        n = N;
        mm.clear();
        ll ret =  dfs(N, K, 1);
        for (int i = 1; i <= K; i++) {
            (ret *= mod_inverse(i, MOD)) %= MOD;
        }
        return ret;
    }
};