mayoko’s diary

プロコンとかいろいろ。

SRM 684 div1 med: DivFree

解法

kmjp さんのブログを参考にしました。
kmjp.hatenablog.jp

包除原理(ってこれは言うのか)で解きます。

ok[i] = (i 個の数を並べた時の valid な数列の数), ng[i] = (i 個の数を並べた時の valid でない数列の数) とします。ok[i]+ng
[i] = k^i が成り立つので, ng[i] が求まればハッピーです。

ng[i] を求めるために, bad[j] = (j 個の連続した数列が ng になっているような場合の数) というのを数えておきます。例えば 12, 6, 3, 1 みたいなのは ng ですが, この数列は bad[4] を構成する数列の 1 つですね。bad[j] の j は 2^j <= k を満たす範囲しか取れないので, 20 程度まで考えれば OK です。

で, これを使って ng[i+1] を求めます。まず, ng[i]*k というのを考えると, これは

  • i 個並べるより前に invalid であることが決定しているものに適当な数を付け加えたもの
    • $$$(ngな数列)$$$$ + (数) みたいなやつ
  • i 個並べて今 ng な数列が並んでいる途中の数列
    • $$$$$$(ng な数列) + (数) みたいなやつ

の 2 通りが考えられます。ただ, これだと数えきれていないやつがあります。それは, 「i 番目の数から ng な数列が続いていくような数列」です。上の 2 つでは, もともと ng な数列が用意されていなければならないので, これは無視されているわけですね。

それを考慮するために, ok[i+1-2] * bad[2] というのを考えてみましょう。これで i, i+1 番目で ng な数列ができているものを考慮出来て万事解決…ではないです。これを考えると, 「i-1 番目の数から i+1 番目まで ng な数列が続くもの」も数えられて閉まっています。これは ng[i]*k で既に考慮しているものなので, ダブって数えていることになります。

そこで今度は ok[i+1-3] * bad[3] を引きます。すると今度は…という形で続いていきます。bad[j] の最大値が 5 の場合を見てみると,

  • 長さ 3 4 5 を考慮
  • 長さ 2 3 を加える
  • 長さ 3 4 を減らす
  • 長さ 4 5 を加える
  • 長さ 5 を減らす(j = 5 が最大なので, 6 も減らす, ということはない)

と処理していって, 結局「長さ 2 3 4 5 を考慮」したことになるので, 包除原理のように処理していけば OK ということになります。

const ll MOD = 1e9+7;
const int MAX = 50050;

ll ng[MAX], ok[MAX];
ll way[20][MAX];
ll bad[20];

class DivFree {
public:
    int dfcount(int n, int k) {
        memset(way, 0, sizeof(way));
        for (int i = 1; i <= k; i++) way[1][i] = 1;
        for (int i = 1; i < 20-1; i++) {
            for (int j = 1; j <= k; j++) {
                for (int l = 2; l*j <= k; l++) {
                    way[i+1][l*j] += way[i][j];
                }
            }
        }
        memset(bad, 0, sizeof(bad));
        for (int i = 0; i < 20; i++) {
            for (int j = 1; j <= k; j++) {
                bad[i] += way[i][j];
            }
            bad[i] %= MOD;
        }
        ll total = 1;
        ok[0] = 1, ng[0] = 0;
        for (int i = 0; i < n; i++) {
            (total *= k) %= MOD;
            ng[i+1] = ng[i]*k;
            for (int j = 2; j < 20; j++) {
                if (i+1-j < 0) break;
                int sgn = (j%2) ? -1 : 1;
                ng[i+1] += sgn*ok[i+1-j]*bad[j];
            }
            ng[i+1] %= MOD;
            ng[i+1] = (ng[i+1]+MOD)%MOD;
            ok[i+1] = (total-ng[i+1]+MOD) % MOD;
        }
        return (int)(ok[n]);
    }
};