mayoko’s diary

プロコンとかいろいろ。

SRM 527 div1 hard:P8XCoinChange

解法

ここに書く解法は, topcoder の解説を参考にして書いてます。
http://apps.topcoder.com/wiki/display/tc/SRM+527
こっちのほうが証明とかいろいろ丁寧です。英語面倒ですがこっち読むほうがオススメです。

まず, 問題を言い換えてみます。与えられた問題は以下の問題と同じです。

「あるウサギが 0 からスタートして coins_sum にちょうど着くようにジャンプしながら進もうとしている。ただし, ジャンプして進
むことの出来る距離は values に書かれるいずれかの距離のみで, また, ジャンプして進む距離は非増加でなければならない(例えば, ジャンプして進む距離が 2 -> 2 -> 4 -> 1 とかなるのは (2 進んだ後 4 進んでいるので) NG)。ウサギが coins_sum にピッタリたどり着くようなジャンプの列は何通りあるか。」

ジャンプして進む距離が非増加である, というのが大事です。

で, この問題に一見関係ないようですが, 以下のようなアルゴリズムを考えます。
「ある地点にいる時, coins_sum を超えない最大の values の要素を選び, 進む。」

例えば values = {1, 3, 6} で, coins_sum = 11 の時, 6, 9, 10, 11 と進んでいきます。このようにして進んだ時に止まった座標(さっきの例では 6, 9, 10, 11) は, 実は上記の問題では必ず通る地点になります。ちゃんとした証明が topcoder の解説には書いてありますが, 直感的に明らかでしょう。ジャンプして進む距離が非増加なので, 必ず踏むことになりそうです。これがキーになります。

values が要素を n 個持っているとします。
solve(length) という関数を考えます。これは n*n の行列を返し, その行列の (i, j) 成分は, 「values[i] 以下のジャンプしかしないで, 最後のジャンプが values[j] であり, 合計で length だけ進んでいるような場合の数」を返します。
こんな関数作れるなら solve(coins_sum) を求めればええやんけwと思いますが, この length には条件があって, length として取れるものは, values の要素だけです。

solve 関数のアルゴリズムはひとまず置いておいて, これがわかるとどうやって答えがわかるのかを考えます。
先ほど示したとおり, coins_sum までの道のりでは, 必ず踏む地点があります。しかも, その必ず踏む地点と次の地点の間は, values の要素で表すことが出来ます。
ということで, 例えば coins_sum までに上のアルゴリズムで p 回 values[i] というジャンプを使うのであれば, solve(values[i])^p を計算すると, values[i] のジャンプを使う間のジャンプの仕方の場合の数がわかります。これはよくある行列累乗テクニックなので OK でしょう(solve(A+B) = solve(A)*solve(B) という関係から, 累乗にしても同じになる)。

よって, solve 関数ができれば, 答えが得られることがわかりました。あとは solve 関数のアルゴリズムを考えましょう。これはそれほど難しくありません。
まず solve(values[0]) は, 0 列目がすべて 1 であるような行列になります。定義から明らかです。
また, k >= 1 を満たす k について, solve(values[k]) は, values[k] のジャンプを使って一気に values[k] まで行くか, values[k-1] 以下のジャンプを最初に使って, 小刻みに values[k] までたどり着くかのいずれかです。
前者は, solve(values[k]) の行列を m として, m[j][k] = 1 (j >= k) となります。
後者の場合は, values[k-1] 以下のジャンプのみを用いて values[k-1] 進むのを values[k]/values[k-1] 回やるのと同じになるので, solve(values[k-1])^(values[k]/values[k-1]) を求めれば OK です。

typedef long long number;
typedef vector<number> vec;
typedef vector<vec> matrix;

const ll MOD = 1e6+3;

matrix Zero(int n) {
    matrix A(n, vec(n));
    return A;
}

// O( n )
matrix identity(int n) {
    matrix A(n, vec(n));
    for (int i = 0; i < n; ++i) A[i][i] = 1;
    return A;
}
// O( n^3 )
matrix mul(const matrix A, const matrix B) {
    matrix C(A.size(), vec(B[0].size()));
    for (int i = 0; i < (int)C.size(); ++i)
        for (int j = 0; j < (int)C[i].size(); ++j)
            for (int k = 0; k < (int)A[i].size(); ++k) {
                C[i][j] += A[i][k] * B[k][j];
                C[i][j] %= MOD;
            }
    return C;
}
// O( n^3 log e )
matrix pow(const matrix &A, ll e) {
    if (e == 0) return identity(A.size());
    if (e == 1) return A;
    if (e % 2 == 0) {
        matrix tmp = pow(A, e/2);
        return mul(tmp, tmp);
    } else {
        matrix tmp = pow(A, e-1);
        return mul(A, tmp);
    }
}

class P8XCoinChange {
public:
    int solve(long long coins_sum, vector<long long> values) {
        int n = values.size();
        vector<matrix> M(n);
        M[0] = Zero(n);
        for (int i = 0; i < n; i++) {
            M[0][i][0] = 1;
        }
        for (int i = 1; i < n; i++) {
            ll e = values[i]/values[i-1];
            M[i] = pow(M[i-1], e);
            for (int j = i; j < n; j++) M[i][j][i]++;
        }
        matrix total = identity(n);
        for (int i = n-1; i >= 0; i--) {
            ll p = coins_sum/values[i];
            coins_sum %= values[i];
            total = mul(total, pow(M[i], p));
        }
        ll ans = 0;
        for (int i = 0; i < n; i++) (ans += total[n-1][i]) %= MOD;
        return ans;
    }
};