mayoko’s diary

プロコンとかいろいろ。

SRM 490 div1 hard: InfiniteLab

解法

実際の迷路は無限に縦に続いていますが, 入力で与えられる一つの迷路のことを「パターン」と呼ぶことにします。

まず, r1 と r2 が同じパターンに属する場合を考えます。一つ注意というか, そこがこの問題を難しくしているところなのですが, r1 と r2 が同じパターンに属しているからと言って, 一つのパターンだけを通って (r1, c1) から (r2, c2) にたどり着けるわけではありません。

多分上下 20 個ずつ同じパターンを用意すれば良いような気がするのですが, 下のコードでは上下に同じパターンを 100 個ずつ用意して, (r1, c1) から (r2, c2) への最短経路を幅優先探索で求めています。

はい, 同じパターンはこれで大丈夫です。次に r1 と r2 が別のパターンに属する場合ですが, とりあえず 0 <= r1 < H < r2 としておきます。r1 から r2 に行くためには, 必ず H, 2*H, ..., n*H (n = r2/H)の行を通らなければならない, ということに注目します。(k*H, i) から ((k+1)*H, j) までの距離は (0, i) から (H, j) までの距離と等しいことを考慮すると, (r1, c1) から (r2, c2) まで行くための経路は, 適当な i, j, k, ... を使って, ((r1, c1) から (H, i) までの距離) + ((0, i) から (H, j) までの距離) + ... + ((0, k) から (r2%H, c2) までの距離) と 書くことが出来ます。

ところで, mat[i][j] = ((0, i) から (H, j) にたどり着くための最短経路) とすれば, ((0, i) から (n*H, j) にたどり着くための最短経路) は, 行列累乗っぽく解けます。具体的には, C = A*B とすると, A*B の演算を, C[i][j] = (A[i][k] + B[k][j] の min 値) というようにすれば良いです。

この行列の累乗の計算を利用して, 答えは, dist[r1%H][c1][H][i] + mat[i][j] + dist[0][j][r2%H][c2] の (i, j) に関する最小値, と書けます。ただし, dist[y1][x1][y2][x2] は, ((y1, x1) から (y2, x2) までの最短距離), です。

typedef long long number;
typedef vector<number> vec;
typedef vector<vec> matrix;

const ll INF = 1ll<<60;

matrix identity(int n) {
    matrix A(n, vec(n));
    for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) {
        A[i][j] = INF;
        if (i == j) A[i][j] = 0;
    }
    return A;
}
// O( n^3 )
matrix mul(const matrix &A, const matrix &B) {
    matrix C(A.size(), vec(B[0].size()));
    for (int i = 0; i < (int)C.size(); ++i) {
        for (int j = 0; j < (int)C[i].size(); ++j) {
            C[i][j] = INF;
            for (int k = 0; k < (int)A[i].size(); ++k) {
                C[i][j] = min(C[i][j], A[i][k]+B[k][j]);
            }
        }
    }
    return C;
}
// O( n^3 log e )
matrix pow(const matrix &A, ll e) {
    if (e == 0) return identity(A.size());
    if (e == 1) return A;
    if (e % 2 == 0) {
        matrix tmp = pow(A, e/2);
        return mul(tmp, tmp);
    } else {
        matrix tmp = pow(A, e-1);
        return mul(A, tmp);
    }
}

class InfiniteLab {
public:
    long long getDistance(vector <string> field, long long r1, int c1, long long r2, int c2) {
        int H = field.size();
        int W = field[0].size();
        const int X = 100;
        int tH = H*(2*X+1);
        // teleports に関するメモ
        vector<vi> teleports(H);
        for (int i = 0; i < H; i++) {
            for (int j = 0; j < W; j++) {
                if (field[i][j] == 'T') {
                    if (teleports[i].empty()) {
                        teleports[i].resize(2);
                        teleports[i][0] = j;
                    } else {
                        teleports[i][1] = j;
                    }
                }
            }
        }
        // pattern 0 からの距離メモ
        vector<vector<ll> > dist0[22][22];
        for (int i = 0; i < H; i++) for (int j = 0; j < W; j++) {
            dist0[i][j].resize(H+1);
            for (int k = 0; k <= H; k++) dist0[i][j][k].resize(W);
        }
        // pattern 0 の各頂点について bfs
        vector<vector<ll> > dist(tH, vector<ll>(W));
        for (int i = 0; i < H; i++) {
            for (int j = 0; j < W; j++) {
                int ox = j;
                int oy = i+X*H;
                for (int y = 0; y < tH; y++) {
                    for (int x = 0; x < W; x++) {
                        dist[y][x] = INF;
                    }
                }
                queue<pii> que;
                dist[oy][ox] = 0;
                if (field[oy%H][ox] != '#') {
                    que.push(pii(oy, ox));
                    while (!que.empty()) {
                        pii p = que.front(); que.pop();
                        int y = p.first, x = p.second;
                        for (int k = 0; k < 4; k++) {
                            int nx = x+dx[k], ny = y+dy[k];
                            if (nx < 0 || nx >= W || ny < 0 || ny >= tH) continue;
                            if (field[ny%H][nx] == '#' || dist[ny][nx] <= dist[y][x] + 1) continue;
                            dist[ny][nx] = dist[y][x] + 1;
                            que.push(pii(ny, nx));
                        }
                        if (field[y%H][x] == 'T') {
                            for (int t2 = 0; t2 < 2; t2++) {
                                int x2 = teleports[y%H][t2];
                                if (x != x2) {
                                    if (dist[y][x2] > dist[y][x]+1) {
                                        dist[y][x2] = dist[y][x]+1;
                                        que.push(pii(y, x2));
                                    }
                                }
                            }
                        }
                    }
                }
                for (int y2 = 0; y2 <= H; y2++) {
                    for (int x2 = 0; x2 < W; x2++) {
                        dist0[i][j][y2][x2] = dist[y2+X*H][x2];
                    }
                }
            }
        }
        if (r1 > r2) {
            swap(r1, r2);
            swap(c1, c2);
        }
        ll diff = r2-r1;
        int fr1, fr2;
        if (r1 < 0) fr1 = (int)((H-((-r1)%H))%H);
        else fr1 = r1%H;
        r1 = fr1;
        r2 = diff+r1;
        fr2 = (int)(r2%H);

        ll ans = INF;
        if (r1/H == r2/H) {
            ans = dist0[fr1][c1][fr2][c2];
        } else {
            ll a1 = r1/H, a2 = r2/H;
            ll p = a2-a1-1;
            matrix A(W, vec(W));
            for (int i = 0; i < W; i++) for (int j = 0; j < W; j++) {
                A[i][j] = dist0[0][i][H][j];
            }
            A = pow(A, p);
            for (int i = 0; i < W; i++) for (int j = 0; j < W; j++) {
                ans = min(ans, dist0[fr1][c1][H][i] + A[i][j] + dist0[0][j][fr2][c2]);
            }
        }
        if (ans >= INF) ans = -1;
        return ans;
    }
};