mayoko’s diary

プロコンとかいろいろ。

SRM 540 div1 med: RandomColoring

解法わかっても結構エグイ問題。

問題

TopCoder Statistics - Problem Statement

0, 1, ..., N-1 番目の看板に順番に色を塗っていく。塗る色は RGB 値で構成されている。i 番目と i+1 番目の看板の間で塗る色には d1, d2 で表される制約があって, i 番目の看板に塗る色の RGB 値をそれぞれ r, g, b, i+1 番目の看板に塗る色の RGB 値をそれぞれ R, G, B とすると,

  • |R-r| >= d2, |G-g| >= d2, |B-b| >= d2 がすべて成り立つ
  • |R-r| <= d1, |G-g| <= d1, |B-b| <= d1 のうち少なくとも一つが成り立つ

が満たされなければならない。
i 番目から i+1 番目の看板を塗るときは, (r, g, b) から上記の条件を満たす色のうちいずれかを等確率で選択する。

このやり方で塗っていくと, 0 <-> 1, 1 <-> 2, ..., N-2 <-> N-1 の間では制約を満たすが, N-1 <-> 0 の間では制約を満たすとは限らない(看板は円周上に並んでいるので 0 と N-1 も隣り合っている設定)。
上の塗り方をしていった時に 0 <-> N-1 の間で制約を満たしていない確率を求めよ。ただし, 0 番目の看板の RGB 値は決まっている。

解法

普通に考えると,
dp[i][r][g][b] = (i 番目の看板の RGB 値が r, g, b になる確率) という dp をどんどん更新していく感じになりますが, 更新に O(RGB) かかるので, 単純な遷移では間に合いません。そこで, 多次元いもす法を使います。
いもす法 - いもす研 (imos laboratory)

今回の場合は, 一辺の長さ 2*d2 の立方体の真ん中から, 一辺の長さ 2*d1 の立方体をくりぬくようなイメージで一様に値を足していけば良いです(こうすると, 全部距離が d2 以内になって, かつ次の RGB 値がすべて d1 以内の差で遷移することが防げる)。

3 次元いもす自体はめんどくさそうに見えて実はそれほどめんどくさくないです。
いもす法解説ページに書いてある「特殊な座標系への拡張」をやって確認してみるとわかりますが, 結局 pairty のようなものを考えて parity が奇数ならマイナス, 偶数ならプラスするようになります。

double dp[2][55][55][55];

class RandomColoring {
public:
    double getProbability(int N, int maxR, int maxG, int maxB, int startR, int startG, int startB, int d1, int d2) {
        d1--;
        // 初期化
        for (int i = 0; i < 2; i++) for (int j = 0; j < 55; j++) {
            for (int k = 0; k < 55; k++) for (int l = 0; l < 55; l++) {
                dp[i][j][k][l] = 0;
            }
        }
        dp[0][startR][startG][startB] = 1;
        for (int i = 0; i < N-1; i++) {
            int cur = i%2, tar = cur^1;
            // 初期化
            for (int r = 0; r < maxR; r++) for (int g = 0; g < maxG; g++) for (int b = 0; b < maxB; b++) {
                dp[tar][r][g][b] = 0;
            }
            // imos
            for (int r = 0; r < maxR; r++) for (int g = 0; g < maxG; g++) for (int b = 0; b < maxB; b++) {
                int lr = max(r-d2, 0), rr = min(r+d2, maxR-1)+1;
                int lg = max(g-d2, 0), rg = min(g+d2, maxG-1)+1;
                int lb = max(b-d2, 0), rb = min(b+d2, maxB-1)+1;
                int lr1 = max(r-d1, 0), rr1 = min(r+d1, maxR-1)+1;
                int lg1 = max(g-d1, 0), rg1 = min(g+d1, maxG-1)+1;
                int lb1 = max(b-d1, 0), rb1 = min(b+d1, maxB-1)+1;
                int V2 = (rr-lr)*(rg-lg)*(rb-lb);
                int V1 = (rr1-lr1)*(rg1-lg1)*(rb1-lb1);
                if (d1==-1) V1 = 0;
                if (V2-V1==0) continue;
                double p = dp[cur][r][g][b]/(V2-V1);

                dp[tar][lr][lg][lb] += p;
                dp[tar][lr][lg][rb] -= p;
                dp[tar][lr][rg][lb] -= p;
                dp[tar][lr][rg][rb] += p;
                dp[tar][rr][lg][lb] -= p;
                dp[tar][rr][lg][rb] += p;
                dp[tar][rr][rg][lb] += p;
                dp[tar][rr][rg][rb] -= p;

                if (d1 != -1) {
                    p *= -1;
                    dp[tar][lr1][lg1][lb1] += p;
                    dp[tar][lr1][lg1][rb1] -= p;
                    dp[tar][lr1][rg1][lb1] -= p;
                    dp[tar][lr1][rg1][rb1] += p;
                    dp[tar][rr1][lg1][lb1] -= p;
                    dp[tar][rr1][lg1][rb1] += p;
                    dp[tar][rr1][rg1][lb1] += p;
                    dp[tar][rr1][rg1][rb1] -= p;
                }
            }
            for (int r = 0; r < maxR; r++) for (int g = 0; g < maxG; g++) for (int b = 0; b < maxB; b++) {
                dp[tar][r][g][b+1] += dp[tar][r][g][b];
            }
            for (int r = 0; r < maxR; r++) for (int b = 0; b < maxB; b++) for (int g = 0; g < maxG; g++) {
                dp[tar][r][g+1][b] += dp[tar][r][g][b];
            }
            for (int g = 0; g < maxG; g++) for (int b = 0; b < maxB; b++) for (int r = 0; r < maxR; r++) {
                dp[tar][r+1][g][b] += dp[tar][r][g][b];
            }
        }
        double ret = 0;
        for (int r = 0; r < maxR; r++) for (int g = 0; g < maxG; g++) for (int b = 0; b < maxB; b++) {
            int dr = abs(r-startR);
            int dg = abs(g-startG);
            int db = abs(b-startB);
            if (dr <= d2 && dg <= d2 && db <= d2 && (dr > d1 || dg > d1 || db > d1)) ret += dp[(N-1)%2][r][g][b];
        }
        return 1-ret;
    }
};