mayoko’s diary

プロコンとかいろいろ。

AtCoder Regular Contest 050 D - Suffix Concat

解法

問題設定が以前解いた問題に似ています。
mayokoex.hatenablog.com

これから, 文字列のソートの仕方は, p < q ⇔ p+q < q+p とすれば良いことはわかるのですが, 単純にソートすると計算量が O(N^2 log N) みたいな感じになって間に合いません。

suffix array の lcp を使うと上手く行きそうな気がするのでそれを考えましょう。
suffix array 求める -> lcp[i] = (i, i+1 の共通接頭辞) を求める -> (i, j の共通接頭辞を高速で求められるようにする)
というところまでやっておきます(suffix array 関係でココらへんはよくやることっぽいので後で記事にするつもりです)。

そうすると, s.substr(lhs) と s.substr(rhs) の大小関係は, 以下のようにわかります(lhs < rhs としておく)。

まず, lhs と rhs の共通接頭辞の長さ len を求めます。ここで, rhs+len < N の場合は, s[lhs+len] と s[rhs+len] が違うということなので, この大小関係を見れば全体の大小関係もわかります(s の配列を直接見に行っても良いけど, 下のコードでは suffix array 配列の位置関係から見に行ってます)。

一方で, rhs+len == N の時は, s.substr(rhs)+s.substr(lhs) の方は, 前半部分(+ の前の部分) はすべて見たことになり, s.substr(lhs)+s.substr(rhs) の残りの部分は s.substr(lhs+len) + s.substr(rhs) となっています。よって, 次は (N-(lhs+len)) 文字分だけ, s.substr(lhs) と s.substr(lhs+len) を比較すべきに見えますが, 実は後は s.substr(lhs) と s.substr(lhs+len) を比較すれば OK です。

これはなんでかと言うと, s.substr(lhs) と s.substr(lhs+len) の N-(lhs+len) 文字が一致していたら, 残りの文字はすべて一致していることが確認できるからです。len = N-rhs が成り立つことから適当に逆算するとわかります。ということで, 後は suffix array で lhs と lhs+len を比較しましょう。

このような比較関数を作ってソートすれば OK です。

const int MAXN = 100010;
namespace SA {
    int rank[MAXN], tmp[MAXN];
    int n, k;
    bool compare_sa(int i, int j) {
        if (rank[i] != rank[j]) return rank[i] < rank[j];
        int ri = (i+k <= n) ? rank[i+k] : -1;
        int rj = (j+k <= n) ? rank[j+k] : -1;
        return ri < rj;
    }
    // suffix array を構築する
    // O(N log^2 N)
    // how to use: construct_saを呼ぶとsa配列にSuffixArrayを構築する
    void createSA(const string& s, int* sa) {
        n = s.size();
        // 最初は 1 文字ソート
        for (int i = 0; i <= n; i++) {
            sa[i] = i;
            rank[i] = i < n ? s[i] : -1;
        }
        for (k = 1; k <= n; k*=2) {
            sort(sa, sa+n+1, compare_sa);
            tmp[sa[0]] = 0;
            for (int i = 1; i <= n; i++) {
                tmp[sa[i]] = tmp[sa[i-1]] + (compare_sa(sa[i-1], sa[i]) ? 1 : 0);
            }
            for (int i = 0; i <= n; i++) rank[i] = tmp[i];
        }
    }
    namespace LCP {
        int rank[MAXN];
        // suffix array の情報をもとに longest common prefix を構築する
        // O(N)
        // hot to use: construct_lcpに文字列と suffix array の情報を入れると lcp 配列を作る
        void createLCP(const string& s, const int* sa, int* lcp) {
            int n = s.size();
            for (int i = 0; i <= n; i++) rank[sa[i]] = i;
            int h = 0;
            lcp[0] = 0;
            for (int i = 0; i < n; i++) {
                int j = sa[rank[i]-1];
                h = max(0, h-1);
                for (; j+h < n && i+h < n; h++) {
                    if (s[j+h] != s[i+h]) break;
                }
                lcp[rank[i]-1] = h;
            }
        }
        // sparse table を lcp 配列をもとに構築する
        // getLCP を呼ぶ前に読んでおかないとダメ
        // st[j][i] は lcp[i], lcp[i+1], ..., lcp[i+(2^j)-1] の最小値
        // lcp は (i, i+1) の共通接頭辞を求められるので, st[j][i] は (i, j) の共通接頭辞を求められる
        int st[21][MAXN];
        void initSparseTable(int n, const int* lcp) {
            int h = 1;
            while ((1<<h) < n) h++;
            for (int i = 0; i < n; i++) st[0][i] = lcp[i];
            for (int j = 1; j <= h; j++) {
                for (int i = 0; i <= n-(1<<j); i++) {
                    st[j][i] = min(st[j-1][i], st[j-1][i+(1<<(j-1))]);
                }
            }
        }
        inline int topBit(int t) {
            for (int i = 20; i >= 0; i--) {
                if ((1<<i)&t) return i;
            }
            return -1;
        }
        // suffix array の, "辞書順で" f 番目と s 番目の文字列における longest common prefix を求める
        // s.substr(f) と s.substr(s) の lcp じゃないからね
        int getLCP(int f, int s) {
            if (f > s) swap(f, s);
            int diff = topBit(s-f);
            return min(st[diff][f], st[diff][s-(1<<diff)]);
        }
    } // namespace LCP
} // namespace SA


int sa[MAXN], lcp[MAXN];
int invSA[MAXN], perm[MAXN];
int N;

bool compare(int lhs, int rhs) {
    bool flip = false;
    if (lhs > rhs) {
        flip = true;
        swap(lhs, rhs);
    }
    int len = SA::LCP::getLCP(invSA[lhs], invSA[rhs]);
    if (rhs+len < N) return flip ^ (invSA[lhs] < invSA[rhs]);
    return flip ^ (invSA[lhs+len] < invSA[lhs]);
}

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    cin >> N;
    string s;
    cin >> s;
    SA::createSA(s, sa);
    SA::LCP::createLCP(s, sa, lcp);
    SA::LCP::initSparseTable(N, lcp);
    for (int i = 0; i <= N; i++) invSA[sa[i]] = i;
    for (int i = 0; i < N; i++) perm[i] = i;
    sort(perm, perm+N, compare);
    for (int i = 0; i < N; i++) cout << perm[i]+1 << endl;
    return 0;
}