mayoko’s diary

プロコンとかいろいろ。

square869120Contest H - 3人の昼食 (The Lunch)

解法

半分全列挙 + 平面走査 で解きます。

半分の要素について, 残す食品が e 個あるという前提での残りの食品のわけかたを全列挙します。
square1001とE869120の合計値段の差を x 軸に, square1001とうさぎの合計値段の差を y 軸にとるように, 列挙した要素を並べてみます。前半の全列挙要素を A, 後半の全列挙要素を B としましょう。A, B はどちらも x, y についてソートしておきます(x 軸優先。要するに pair をソートする感じ)。

A の要素 A[i] に対して, abs(A[i].x + B[j].x) <= d となる要素 B[j] が, square1001とE869120の合計値段の差 を d 以下にする要素です。各 i について, j の範囲がどうなるかはしゃくとりっぽくすれば合計 O(N) で求めることが出来ます。

また, abs(A[i].y + B[j].y) <= d となっていなければなりませんが, これは Binary Indexed Tree を使って, 今考えている要素で, 左の条件を満たすようなものはいくつあるかを数えることが出来ます。

// 0-based Binary Indexed Tree
template<typename T> struct BIT {
    int max;
    vector<T> bit;
    BIT(int max) : max(max) {bit.resize(max+1);}
    // [0, i)
    T sum(int i) {
        T s = 0;
        while (i > 0) {
            s += bit[i];
            i ^= i&-i;
        }
        return s;
    }
    // 0-basedな座標iに値xを追加する
    void add(int i, T x) {
        ++i;
        while (i <= max) {
            bit[i] += x;
            i += i&-i;
        }
    }
    // [a, b)
    T sum(int a, int b) {
        return sum(b)-sum(a);
    }
    // sum(0, i) >= wとなる最小のiを求める 存在しなければmaxを返す
    int lb(T w) {
        if (w <= 0) return 0;
        int k = 1;
        while (k <= max) k <<= 1;
        int i = 0;
        for (; k > 0; k >>= 1) if (i+k <= max && bit[i+k] < w) {
            w -= bit[i+k];
            i += k;
        }
        return i+1;
    }
};

const int MAXN = 22;
int A[MAXN];
bool ng[MAXN];

vector<pll> calc(const vll& v, int e) {
    vector<pll> ans;
    int size = v.size();
    if (size <= e) return ans;
    int p = 1;
    for (int i = 0; i < size-e; i++) p *= 3;
    for (int s = 0; s < (1<<size); s++) {
        memset(ng, false, sizeof(ng));
        int cnt = 0;
        for (int i = 0; i < size; i++) {
            if ((s>>i)&1) {
                ng[i] = true;
                cnt++;
            }
        }
        if (cnt != e) continue;
        for (int i = 0; i < p; i++) {
            vll s(3);
            int memo = i;
            for (int j = 0; j < size; j++) {
                if (ng[j]) continue;
                s[memo%3] += v[j];
                memo /= 3;
            }
            ans.emplace_back(s[0]-s[1], s[0]-s[2]);
        }
    }
    sort(ans.begin(), ans.end());
    return ans;
}

ll only(const vector<pll>& P, int d) {
    int size = P.size();
    ll ret = 0;
    for (int i = 0; i < size; i++) {
        if (abs(P[i].first) <= d && abs(P[i].second) <= d) ret++;
    }
    return ret;
}

int lb(ll a, const vll& v) {
    return lower_bound(v.begin(), v.end(), a) - v.begin();
}

ll solve(const vector<pll>& A, const vector<pll>& B, ll d) {
    if (A.size() == 0 || B.size() == 0) return 0;
    vll v;
    for (pll p : B) v.push_back(p.second);
    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());
    BIT<ll> bit(v.size()+3);
    int start = 0, end = -1;
    int Asize = A.size(), Bsize = B.size();
    for (int i = 0; i < Bsize; i++) {
        if (A[0].first+B[i].first < -d) start = i+1;
        if (A[0].first+B[i].first > d) {
            end = i;
            break;
        }
    }
    if (end == -1) end = Bsize;
    if (end == 0) return 0;
    ll ret = 0;
    for (int i = start; i < end; i++) bit.add(lb(B[i].second, v), 1);
    for (int i = 0; i < Asize; i++) {
        while (start > 0) {
            if (A[i].first+B[start-1].first >= -d) {
                start--;
                bit.add(lb(B[start].second, v), 1);
            } else break;
        }
        while (end > 0) {
            if (A[i].first+B[end-1].first > d) {
                end--;
                bit.add(lb(B[end].second, v), -1);
            } else break;
        }
        ll low = (-d) - A[i].second;
        ll high = d - A[i].second;
        int id1 = lb(low, v);
        int id2 = lb(high+1, v);
        ret += bit.sum(id2);
        ret -= bit.sum(id1);
    }
    return ret;
}

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    int N, D, E;
    cin >> N >> D >> E;
    for (int i = 0; i < N; i++) cin >> A[i];
    int n = N/2;
    vector<ll> P1, P2;
    for (int i = 0; i < n; i++) P1.push_back(A[i]);
    for (int i = n; i < N; i++) P2.push_back(A[i]);
    vector<vector<pll> > q1(E+1), q2(E+1);
    for (int i = 0; i <= E; i++) q1[i] = calc(P1, i);
    for (int i = 0; i <= E; i++) q2[i] = calc(P2, i);
    ll ans = 0;
    for (int i = 0; i <= E; i++) {
        if (P2.size()+i <= E) ans += only(q1[i], D);
        if (P1.size()+i <= E) ans += only(q2[i], D);
    }
    for (int i = 0; i <= E; i++) {
        for (int j = 0; j <= E-i; j++) ans += solve(q1[i], q2[j], D);
    }
    cout << ans << endl;
    return 0;
}