mayoko’s diary

プロコンとかいろいろ。

Codeforces Round #299 (Div. 1) B. Tavas and Malekas

ハッシュポテト食べたい。

解法

Z algorithmとかいうのがあるらしいですがハッシュで解決させてしまいました。

まずなんでハッシュ値が必要なのかの話から…

今回の問題ではy[i]とy[i+1]がp以上離れていたら特に何も問題なく文字列sの整合性が取れていることがわかります。この場合はy[i+1]-(y[i]+p)文字分は好きに文字を選んで良いのでこの数分だけ26を掛けあわせます(0からy[0]の間及びy[m-1]からnの間も好きに埋めてい良いことに注意。それをずっと忘れてWA出しまくった…)。

ということで問題はy[i]とy[i+1]が干渉し合っているときどうするかです。これは

「文字列pのy[i+1]-y[i]文字目から最後までと文字列pの0文字目からn-(y[i+1]-y[i])文字目が一致しているか」(適当に書いたので±1はずれてるかも)

という問題に帰着できます。文字列の数が多いのでなるべく高速に答えなければなりませんがこの要望に答えるのがハッシュです。


ということでどうやってハッシュを使ったのかの話に移ります。

最初に全体の文字列のハッシュ値を以下のようにして計算します。

 hash_n = \sum_0^n (255^i * s_i) mod (10^9+7)

 s_iは文字を数値に変換したやつです。このようにして計算すると,部分文字列のハッシュ値は以下のように簡単に求めることが出来ます。

 hash_{l, r} = hash_r - hash_{l-1} mod (10^9+7)
ここで hash_{l, r}は[l, r]の区間の文字のハッシュ値を表します。しかしこれだけの計算式だと単純に部分文字列を計算した場合の 255^l倍になってしまっています。本来なら 255^lの逆元を計算してちゃんと求めるべきなのでしょうが今回の場合は比較しなければならない文字列は
・文字列sの0番目から始まる文字列
・文字列sのl番目から始まる文字列
の2種類のみなのでひとつ目の方の文字列のハッシュ値 255^l倍すれば解決します。

以下ソースコード

typedef long long Hash;

const ll MOD = 1e9+7;
const Hash b1 = 1e9+7;
const Hash b2 = 1e9+9;
const int MAXM = 1000100;

Hash h[MAXM][2];
int y[MAXM];
Hash p255[MAXM][2];
ll p26[MAXM];

void init(const string& s) {
    int n = s.size();
    h[0][0] = h[0][1] = (int)(s[0]);
    for (int i = 1; i < n; i++) {
        h[i][0] = (h[i-1][0] + (int)(s[i]) * p255[i][0]) % b1;
        h[i][1] = (h[i-1][1] + (int)(s[i]) * p255[i][1]) % b2;
    }
}

// [l, r]
pair<Hash, Hash> getHash(int l, int r) {
    Hash h1, h2;
    if (l == 0) {
        h1 = h[r][0], h2 = h[r][1];
    } else {
        h1 = (h[r][0]+b1-h[l-1][0]) % b1;
        h2 = (h[r][1]+b2-h[l-1][1]) % b2;
    }
    return make_pair(h1, h2);
}

bool same(int x, int y, int n) {
    auto p1 = getHash(y-x, n-1);
    auto p2 = getHash(0, n-y+x-1);
    (p2.first *= p255[y-x][0]) %= b1;
    (p2.second *= p255[y-x][1]) %= b2;
    return p1 == p2;
}

string s;

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    p255[0][0] = p255[0][1] = 1;
    p26[0] = 1;
    for (int i = 1; i < MAXM; i++) {
        p255[i][0] = (p255[i-1][0] * 255) % b1;
        p255[i][1] = (p255[i-1][1] * 255) % b2;
        p26[i] = (p26[i-1] * 26) % MOD;
    }
    int n, m;
    cin >> n >> m;
    cin >> s;
    int p = s.size();
    ll ans = 1;
    for (int i = 0; i < m; i++) {
        cin >> y[i];
        y[i]--;
    }
    if (m == 0) {
        for (int i = 0; i < n; i++) {
            (ans *= 26) %= MOD;
        }
        cout << ans << endl;
        return 0;
    }
    init(s);
    for (int i = 0; i < m-1; i++) {
        if (y[i] + p <= y[i+1]) {
            (ans *= p26[(y[i+1]-(y[i]+p))]) %= MOD;
        } else {
            if (same(y[i], y[i+1], p)) {
                continue;
            } else {
                ans *= 0;
                break;
            }
        }
    }
    (ans *= p26[n-y[m-1]-p]) %= MOD;
    (ans *= p26[y[0]]) %= MOD;
    cout << ans << endl;
    return 0;
}