mayoko’s diary

プロコンとかいろいろ。

SRM 528 div1 med: SPartition

解法

半分全列挙します。

N を与えられる文字列 s の長さであるとして, n = N/2 とします。最後の n 文字を, red と blue のどちらにするかは O(2^n) で決められますが, この時 red の文字列のほうが長い場合は, blue の文字列の長さ分, red と blue の最後の文字列は一致している必要があります。

例を挙げます。s = "oxoxxxoo" という文字列を考えた時, 後半 4 文字の文字列は "xxoo" です。
ここで, red = "xoo", blue = "x" というように文字を選択したとすると, blue のほうが文字列が短いので red の最後の 1 文字と blue の最後の 1 文字が一致していなければ全体でも合うことはないですが, 今回の場合は一致していないので NG です。一方で, red = "xxo", blue = "o" とした場合は, 最後の文字が一致しているので OK です。

で, このように最後の文字が一致した場合は, 前半の n 文字でずれている文字分帳尻を合わせられれば良いです。上の例では, 前半部分で blue が 最後に "xx" をつけるようにすれば良いとわかります。

以上の情報を保持したいので, mp[diff][s] = (red の方が後半の文字数が長いかどうかのフラグが diff で, 前半で帳尻合わせをしないといけない文字列の内容が s であるようなものの場合の数) という配列を作ってメモっておきます。これを下のコードのように map でやると, 実行時間が 1.8 秒になります。実行時間も余裕で通したい場合は s を int で表現するなどして工夫しましょう。

メモしたら, 前半の文字の選び方も O(2^n) で全探索し, 前半でついた差を帳尻合わせできるような後半の文字の選び方が何通りあるかを調べていけば良いです。

class SPartition {
public:
    long long getCount(string s) {
        int N = s.size();
        int n = N/2;
        map<pair<string, int>, ll> mp;
        for (int i = 0; i < 1<<n; i++) {
            string red, blue;
            for (int j = 0; j < n; j++) {
                if ((i>>j)&1) red += s[j+n];
                else blue += s[j+n];
            }
            int rcnt = red.size(), bcnt = blue.size();
            int cnt = min(rcnt, bcnt);
            bool ok = true;
            for (int j = 0; j < cnt; j++) {
                if (red[rcnt-j-1] != blue[bcnt-j-1]) {
                    ok = false;
                    break;
                }
            }
            if (ok) {
                if (rcnt > bcnt) mp[make_pair(red.substr(0, rcnt-bcnt), 0)]++;
                else if (rcnt < bcnt) mp[make_pair(blue.substr(0, bcnt-rcnt), 1)]++;
                else mp[make_pair("", 2)]++;
            }
        }
        ll ans = 0;
        for (int i = 0; i < 1<<n; i++) {
            string red, blue;
            for (int j = 0; j < n; j++) {
                if ((i>>j)&1) red += s[j];
                else blue += s[j];
            }
            int rcnt = red.size(), bcnt = blue.size();
            int cnt = min(rcnt, bcnt);
            bool ok = true;
            for (int j = 0; j < cnt; j++) {
                if (red[j] != blue[j]) {
                    ok = false;
                    break;
                }
            }
            if (ok) {
                string t;
                int tmp;
                if (rcnt > bcnt) {
                    t = red.substr(cnt, rcnt-bcnt);
                    tmp = 1;
                } else if (rcnt < bcnt) {
                    t = blue.substr(cnt, bcnt-rcnt);
                    tmp = 0;
                } else {
                    t = "";
                    tmp = 2;
                }
                ans += mp[make_pair(t, tmp)];
            }
        }
        return ans;
    }
};