mayoko’s diary

プロコンとかいろいろ。

SRM 664 div1 med:BearAttacks

前言ったかもですがmedはたいてい「気付きの『けり』」が2つ必要な印象。
今回は1つも気づきませんでした。

解法

まず気づかないといけないのは,「rigionの大きさの2乗の合計」は,「経路がつながっている頂点の(順序付きの)ペアの数」で置き換えられるということです(nico_shindanninさんの説明がわかりやすいです)。

これを考慮して,答えを求めます。単純に考えるとペア同士がつながっている確率を各ペアについて求めるのでO(n^2)かかりますが,
dp[v] = (vを根とする部分木のすべての頂点u(vも含む)に対して,vとuがつながっている確率,の総和)として,これを前計算しておくと,O(n)で解くことができるようになります。

まずdpの求め方を考えます。あ,その前に頂点vが生き残る確率prob[v]を求めておきます(コード中ではinverse[v]と書いてる)。最後にN!掛け算することを考えると,doubleで\frac{1}{v+1}と求めておくより整数問題的にv+1の逆元を求めておくのが吉です。要するに,

prob[v] = (v+1)^{-1} \mathrm{mod} 10^9+7

です。

dp[v]を求めます。頂点vの子をc_1, ..., c_mとすると,

dp[v] = prob[v] + dp(c_1) * prob[v] + ... + dp(c_m) * prob[v]

で求められます。解釈としては,ペア(v, v)の存在する期待値と,vから「vの子の部分木の頂点」がペアでつながっている数の期待値を足し算しているという感じですね。

これを使って答えを求めます。確認ですが求めたいのは「経路がつながっている頂点の(順序付きの)ペアの数」です。

まず頂点vと頂点vのペアがつながっている確率はprob[v]ですね。これをすべての頂点に対して答えに加えます。

次に頂点vを根とする部分木が他の頂点とつながっている個数を考えます。これは,各頂点に対して

dp[v] * prob[parent[v]]

dp[v] * (dp[parent[v]] - prob[parent[v]] * dp[v])

を計算して答えに加算していけばよいです。それぞれ,
parent[v]とvの部分木同士でのペアの個数の期待値
parent[v]の,v以外の子による部分木と,vの部分木でのペアの個数の期待値
を表しています。

ちなみにこの説明最初に読んだ時はvの部分木と他の頂点ってparent[v]より上の頂点ともつながってるだろわけわからんと思ったんですが,それはparent[v]とかそれより上(親)の頂点で上の計算をしているときに考慮されています(これ思いつくの難しすぎじゃないですか)。

const int MAXN = 1000100;
const ll MOD = 1e9+7;

int parent[MAXN];
vector<int> child[MAXN];
ll inverse[MAXN];
ll dp[MAXN];

// extgcd
ll extgcd(ll a, ll b, ll& x, ll& y) {
    ll d = a;
    if (b != 0) {
        d = extgcd(b, a % b, y, x);
        y -= (a / b) * x;
    } else {
        x = 1; y = 0;
    }
    return d;
}

// mod_inverse
ll mod_inverse(ll a, ll m) {
    ll x, y;
    extgcd(a, m, x, y);
    return (m+x%m) % m;
}

ll dfs(int v) {
    if (dp[v] >= 0) return dp[v];
    dp[v] = inverse[v];
    for (int ch : child[v]) {
        (dp[v] += dfs(ch) * inverse[v]) %= MOD;
    }
    return dp[v];
}

class BearAttacks {
public:
    int expectedValue(int N, int R0, int A, int B, int M, int LOW, int HIGH) {
        for (int i = 0; i < N; i++) child[i].clear();
        ll R = R0;
        ll BILLION = 1000000000;
        for (int i = 1; i <= N-1; i++) {
            R = ((ll)A*R+B) % M;
            ll MIN = ((ll)(i-1)*LOW) / BILLION;
            ll MAX = ((ll)(i-1)*HIGH) / BILLION;
            int tmp = (int)(MIN+R%(MAX-MIN+1));
            parent[i] = tmp;
            child[tmp].push_back(i);
        }
        for (int i = 1; i < N; i++) {
            cout << i << "  " << parent[i] << endl;
        }
        for (int i = 0; i < N; i++) inverse[i] = mod_inverse(i+1, MOD);
        memset(dp, -1, sizeof(dp));
        dfs(0);
        ll ret = 0;
        for (int i = 0; i < N; i++) (ret += inverse[i]) %= MOD;
        for (int i = 1; i < N; i++) {
            (ret += inverse[parent[i]] * dp[i]) %= MOD;
            ret += dp[i] * (dp[parent[i]] - dp[i] * inverse[parent[i]] % MOD + MOD);
            ret %= MOD;
        }
        ret %= MOD;
        for (ll i = 1; i <= N; i++) (ret *= i) %= MOD;
        return ret;
    }
};