mayoko’s diary

プロコンとかいろいろ。

SRM 593 div1 hard: WolfDelaymasterHard

解法

topcoder の解説を参考にしました。
http://apps.topcoder.com/wiki/display/tc/SRM+593

こっち読むほうが絶対わかりやすいので, 概要しか書きません。下に書いてあるのも, この解説に書いてある変数を普通に使っています。

普通に dp することを考えます。

dp[j] = (2*j 番目までの文字列までで, valid な文字列の数) とすると,
dp[j] の値は, 2*i 番目までの文字列は適当で, 残りの 2*(j-i) 文字は, wwwooo のように最初の j-i 文字を w, 後を o にする, とすれば, dp[j] = dp[i] (0 <= i < j) となります。
これだけなら, dpSum[i] とかいうの作ればええやんけwなんですが, 最後の 2*(j-i) 文字が valid であるかどうかは入力文字列に依存するので, そう簡単ではありません。ですが, 実は各 dp[j] が, dp[j] = dpSum[b[j]] - dpSum[a[j]] + alpha みたいな感じで書けます(つまり, j に遷移する valid な i が区間で書けるということ)。これを詳しく見ていきます。

まず, valid な文字列を 2 種類に分けます。

  • 前半に少なくとも 1 つの 'w' があり, 'o' は全くない。かつ, 後半には 'w' が全くない。
  • 前半はすべて '?'。後半には 'w' が全くない。

2*j 番目の文字列までで, 1 つ目の性質を満たすものを考えます。
maxW[j] = (j-1 以下で, 最後にある 'w' の位置)とすると, 2*i 目以降の 2*(j-i) 文字が valid になるための条件として, i があまりにも 0 に近すぎると, maxW[j] 番目の 'w' が後半側に来てしまって NG です。ちゃんと条件にすると, i <= maxW/2 が条件になります。
他に 3 つの不等式が出てきて, i が決まります(valid になる i の範囲が区間になることがわかる)。

2 つ目の性質を満たすものは, j を決めたら i がどうなるか〜と考えるのではなく, i を決めた時 どの j まで影響があるか〜と考えます。これでも, i-j が valid になる j が区間として出てきます。

なので, dp で出来るでしょ〜という流れですが, 実装がずっと上手くいかなくて悲しい気持ちになったので実装についても書いていきます。

今まで意識してなかったんですが,

cntW[0] = 0;
cntW[i+1] = cnt[i] + (s[i] == 'w');

というように書くと, cntW[i] には, 区間 [0, i) にある 'w' の数が出てきます。

よって, 例えば maxW だったら, ([x, 2*i) に含まれる 'w' の数) = cnt[2*i] - cnt[x] == 0 を満たす最大の x となるみたいな感じです。

で, これ意識すると dp と dpSum の関係も結構わかりやすくなりました。
今回の場合は, dpSum[i] = [0, i) の区間における, dp[j] の合計 となるように dpSum を組んでいます。なので, j における valid な i の範囲を a[j] <= i < b[j] と決めていれば, dp[j] = dpSum[b[j]] - dpSum[a[j]] でスッキリです。

ていうか常にこう書くのが良いかも?
dpSum[0] = 0, dpSum[1] = dp[0] = 1 と書いて, あとは dp[i] が [a, b) の区間和であるなら, dp[i] = dpSum[b] - dpSum[a] ってものすごくわかりやすい(なんかアタリマエかもしれないけど区間系で毎回バグりまくっているので基準が欲しい感じがあった)。

const ll MOD = 1e9+7;
const int MAXN = 1000100;

class WolfDelaymasterHard {
public:
    int countWords(int N, int wlen, int w0, int wmul, int wadd, int olen, int o0, int omul, int oadd) {
        string s;
        for (int i = 0; i < N; i++) s += '?';
        ll x = w0;
        for (int i = 0; i < wlen; i++) {
            s[x] = 'w';
            x = (x*wmul+wadd)%N;
        }
        x = o0;
        for (int i = 0; i < olen; i++) {
            s[x] = 'o';
            x = (x*omul+oadd)%N;
        }
        vector<int> cntW(N+1, 0), cntO(N+1, 0);
        for (int i = 1; i <= N; i++) {
            cntW[i] = cntW[i-1] + (s[i-1] == 'w');
            cntO[i] = cntO[i-1] + (s[i-1] == 'o');
        }
        // i から配る上限を求める([i, c[i]] に配る)
        vector<int> c(N/2);
        for (int i = 0; i < N/2; i++) {
            int low = i, high = N/2+1;
            while (high-low > 1) {
                int med = (high+low)/2;
                int cw = cntW[2*med]-cntW[2*i];
                int co = cntO[(2*med+2*i)/2]-cntO[2*i];
                if (cw == 0 && co == 0) low = med;
                else high = med;
            }
            c[i] = low;
        }
        // i が受け取る下限と上限を求める
        // [a, b)
        vector<int> a(N/2+1), b(N/2+1);
        for (int i = 1; i <= N/2; i++) {
            int maxW;
            {
                int low = -1, high = 2*i;
                while (high-low > 1) {
                    int med = (high+low)/2;
                    if (cntW[2*i]-cntW[med] < 1) high = med;
                    else low = med;
                }
                maxW = low;
            }
            a[i] = 2*N, b[i] = -1;
            if (maxW == -1) continue;
            // a
            a[i] = max(0, maxW-i+1);
            int maxO;
            {
                int low = 2*a[i]-1, high = maxW;
                while (high-low > 1) {
                    int med = (high+low)/2;
                    if (cntO[maxW]-cntO[med] < 1) high = med;
                    else low = med;
                }
                maxO = low;
            }
            a[i] = max(a[i], (maxO+2)/2);
            // b
            b[i] = maxW/2+1;
            int minO;
            {
                int low = maxW+1, high = 2*i;
                while (high-low > 1) {
                    int med = (high+low)/2;
                    if (cntO[med]-cntO[maxW] == 0) low = med;
                    else high = med;
                }
                minO = low;
            }
            b[i] = min(minO-i+1, b[i]);
        }
        // dp
        vector<ll> dp(N/2+10), dpSum(N/2+10), imos(N/2+10);
        ll plus = 0;
        dp[0] = 1; dpSum[1] = 1;
        if (c[0] > 0) {
            plus = dp[0];
            imos[c[0]] = MOD-1;
        }
        for (int i = 1; i <= N/2; i++) {
            dp[i] = plus;
            if (a[i] < b[i]) {
                dp[i] += dpSum[b[i]]-dpSum[a[i]]+MOD;
                dp[i] %= MOD;
            }
            dpSum[i+1] = (dpSum[i] + dp[i]) % MOD;
            (plus += imos[i]) %= MOD;
            if (c[i] > i && c[i] <= N/2) {
                (plus += dp[i]) %= MOD;
                (imos[c[i]] += MOD-dp[i]) %= MOD;
            }
        }
        return (int)dp[N/2];
    }
};