SRM 490 div1 hard: InfiniteLab
解法
実際の迷路は無限に縦に続いていますが, 入力で与えられる一つの迷路のことを「パターン」と呼ぶことにします。
まず, r1 と r2 が同じパターンに属する場合を考えます。一つ注意というか, そこがこの問題を難しくしているところなのですが, r1 と r2 が同じパターンに属しているからと言って, 一つのパターンだけを通って (r1, c1) から (r2, c2) にたどり着けるわけではありません。
多分上下 20 個ずつ同じパターンを用意すれば良いような気がするのですが, 下のコードでは上下に同じパターンを 100 個ずつ用意して, (r1, c1) から (r2, c2) への最短経路を幅優先探索で求めています。
はい, 同じパターンはこれで大丈夫です。次に r1 と r2 が別のパターンに属する場合ですが, とりあえず 0 <= r1 < H < r2 としておきます。r1 から r2 に行くためには, 必ず H, 2*H, ..., n*H (n = r2/H)の行を通らなければならない, ということに注目します。(k*H, i) から ((k+1)*H, j) までの距離は (0, i) から (H, j) までの距離と等しいことを考慮すると, (r1, c1) から (r2, c2) まで行くための経路は, 適当な i, j, k, ... を使って, ((r1, c1) から (H, i) までの距離) + ((0, i) から (H, j) までの距離) + ... + ((0, k) から (r2%H, c2) までの距離) と 書くことが出来ます。
ところで, mat[i][j] = ((0, i) から (H, j) にたどり着くための最短経路) とすれば, ((0, i) から (n*H, j) にたどり着くための最短経路) は, 行列累乗っぽく解けます。具体的には, C = A*B とすると, A*B の演算を, C[i][j] = (A[i][k] + B[k][j] の min 値) というようにすれば良いです。
この行列の累乗の計算を利用して, 答えは, dist[r1%H][c1][H][i] + mat[i][j] + dist[0][j][r2%H][c2] の (i, j) に関する最小値, と書けます。ただし, dist[y1][x1][y2][x2] は, ((y1, x1) から (y2, x2) までの最短距離), です。
typedef long long number; typedef vector<number> vec; typedef vector<vec> matrix; const ll INF = 1ll<<60; matrix identity(int n) { matrix A(n, vec(n)); for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) { A[i][j] = INF; if (i == j) A[i][j] = 0; } return A; } // O( n^3 ) matrix mul(const matrix &A, const matrix &B) { matrix C(A.size(), vec(B[0].size())); for (int i = 0; i < (int)C.size(); ++i) { for (int j = 0; j < (int)C[i].size(); ++j) { C[i][j] = INF; for (int k = 0; k < (int)A[i].size(); ++k) { C[i][j] = min(C[i][j], A[i][k]+B[k][j]); } } } return C; } // O( n^3 log e ) matrix pow(const matrix &A, ll e) { if (e == 0) return identity(A.size()); if (e == 1) return A; if (e % 2 == 0) { matrix tmp = pow(A, e/2); return mul(tmp, tmp); } else { matrix tmp = pow(A, e-1); return mul(A, tmp); } } class InfiniteLab { public: long long getDistance(vector <string> field, long long r1, int c1, long long r2, int c2) { int H = field.size(); int W = field[0].size(); const int X = 100; int tH = H*(2*X+1); // teleports に関するメモ vector<vi> teleports(H); for (int i = 0; i < H; i++) { for (int j = 0; j < W; j++) { if (field[i][j] == 'T') { if (teleports[i].empty()) { teleports[i].resize(2); teleports[i][0] = j; } else { teleports[i][1] = j; } } } } // pattern 0 からの距離メモ vector<vector<ll> > dist0[22][22]; for (int i = 0; i < H; i++) for (int j = 0; j < W; j++) { dist0[i][j].resize(H+1); for (int k = 0; k <= H; k++) dist0[i][j][k].resize(W); } // pattern 0 の各頂点について bfs vector<vector<ll> > dist(tH, vector<ll>(W)); for (int i = 0; i < H; i++) { for (int j = 0; j < W; j++) { int ox = j; int oy = i+X*H; for (int y = 0; y < tH; y++) { for (int x = 0; x < W; x++) { dist[y][x] = INF; } } queue<pii> que; dist[oy][ox] = 0; if (field[oy%H][ox] != '#') { que.push(pii(oy, ox)); while (!que.empty()) { pii p = que.front(); que.pop(); int y = p.first, x = p.second; for (int k = 0; k < 4; k++) { int nx = x+dx[k], ny = y+dy[k]; if (nx < 0 || nx >= W || ny < 0 || ny >= tH) continue; if (field[ny%H][nx] == '#' || dist[ny][nx] <= dist[y][x] + 1) continue; dist[ny][nx] = dist[y][x] + 1; que.push(pii(ny, nx)); } if (field[y%H][x] == 'T') { for (int t2 = 0; t2 < 2; t2++) { int x2 = teleports[y%H][t2]; if (x != x2) { if (dist[y][x2] > dist[y][x]+1) { dist[y][x2] = dist[y][x]+1; que.push(pii(y, x2)); } } } } } } for (int y2 = 0; y2 <= H; y2++) { for (int x2 = 0; x2 < W; x2++) { dist0[i][j][y2][x2] = dist[y2+X*H][x2]; } } } } if (r1 > r2) { swap(r1, r2); swap(c1, c2); } ll diff = r2-r1; int fr1, fr2; if (r1 < 0) fr1 = (int)((H-((-r1)%H))%H); else fr1 = r1%H; r1 = fr1; r2 = diff+r1; fr2 = (int)(r2%H); ll ans = INF; if (r1/H == r2/H) { ans = dist0[fr1][c1][fr2][c2]; } else { ll a1 = r1/H, a2 = r2/H; ll p = a2-a1-1; matrix A(W, vec(W)); for (int i = 0; i < W; i++) for (int j = 0; j < W; j++) { A[i][j] = dist0[0][i][H][j]; } A = pow(A, p); for (int i = 0; i < W; i++) for (int j = 0; j < W; j++) { ans = min(ans, dist0[fr1][c1][H][i] + A[i][j] + dist0[0][j][fr2][c2]); } } if (ans >= INF) ans = -1; return ans; } };