POJ 3074: Sudoku
解法
蟻本に書いてある通りだと思ったんですが, 3076 が遅すぎて解けない…
下のコードではいろいろ定数倍高速化してるんですが, これでやっと 700ms くらいです。
- __builtin_popcount を使う数が 1<<9 までなので配列に前計算
- 列で既に使っている数を rdone にまとめてき, 遷移するときは O(1) でこれらを更新する
- 候補が 1 つしかないところは onePoint 配列に覚えておき, dfs に入った際に候補がある場合はそれを優先的に処理する
int board[9][9]; int rdone[9], cdone[9], celldone[9]; int popcnt[1<<9]; int poppos[1<<9]; int onePoint[100]; int os, ot; bool dfs() { if (os < ot) { int y = onePoint[os]/9, x = onePoint[os]%9; ++os; int done = rdone[y] | cdone[x] | celldone[y/3*3+x/3]; if (popcnt[done] == 9) { --os; return false; } int k = poppos[done]; board[y][x] = k; rdone[y] ^= 1<<k; cdone[x] ^= 1<<k; celldone[y/3*3+x/3] ^= 1<<k; bool flag = dfs(); if (flag) return true; board[y][x] = -1; rdone[y] ^= 1<<k; cdone[x] ^= 1<<k; celldone[y/3*3+x/3] ^= 1<<k; --os; return false; } int t = ot; int maxi = 0, y = -1, x = -1, maxFlag = -1; bool finish = true; for (int i = 0; i < 9; i++) for (int j = 0; j < 9; j++) { //printf("%d %d\n", i, j); if (board[i][j] != -1) continue; finish = false; int done = rdone[i] | cdone[j] | celldone[i/3*3+j/3]; int cnt = popcnt[done]; //printf("%d\n", cnt); if (cnt==9) return false; if (maxi < cnt) { maxi = cnt; y = i, x = j; maxFlag = done; } else if (cnt==8) { onePoint[ot++] = i*9+j; } } if (finish) { for (int i = 0; i < 9; i++) for (int j = 0; j < 9; j++) printf("%d", board[i][j]+1); printf("\n"); return true; } for (int k = poppos[maxFlag]; k < 9; k++) { if (((maxFlag>>k)&1) == 0) { board[y][x] = k; rdone[y] ^= 1<<k; cdone[x] ^= 1<<k; celldone[y/3*3+x/3] ^= 1<<k; bool flag = dfs(); if (flag) return true; board[y][x] = -1; rdone[y] ^= 1<<k; cdone[x] ^= 1<<k; celldone[y/3*3+x/3] ^= 1<<k; } } ot = t; return false; } int main() { string s; for (int i = 0; i < 1<<9; i++) { for (int j = 0; j < 9; j++) if ((i>>j)&1) popcnt[i]++; } for (int i = 0; i < 1<<9; i++) { for (int j = 8; j >= 0; j--) if (((i>>j)&1) == 0) poppos[i] = j; } while (cin >> s) { if (s == "end") break; for (int i = 0; i < 9; i++) for (int j = 0; j < 9; j++) { board[i][j] = (s[i*9+j] == '.' ? -1 : (s[i*9+j]-'1')); } memset(rdone, false, sizeof(rdone)); memset(cdone, false, sizeof(cdone)); memset(celldone, false, sizeof(celldone)); os = ot = 0; for (int i = 0; i < 9; i++) for (int j = 0; j < 9; j++) if (board[i][j] != -1) { rdone[i] |= 1<<(board[i][j]); } for (int i = 0; i < 9; i++) for (int j = 0; j < 9; j++) if (board[j][i] != -1) { cdone[i] |= 1<<(board[j][i]); } for (int i = 0; i < 9; i++) { int top = i/3*3, left = (i%3)*3; for (int j = top; j < top+3; j++) for (int k = left; k < left+3; k++) { if (board[j][k] != -1) celldone[i] |= 1<<(board[j][k]); } } dfs(); } return 0; }