mayoko’s diary

プロコンとかいろいろ。

Educational Codeforces Round 3 D. Gadgets for dollars and pounds

解法

冷静に考えるとちょっと前準備をしてから答えに対して二分探索すれば良かったと思うんですが, 答えを 1 つずつ足していって, それが条件を満たすかを 3 分探索で調べる, という方法で AC しました。

ある日程 ans について, i <= ans を満たす最小の a[i] および j <= ans を満たす最小の b[j] を求めておきます。おもちゃを買うときは, この a[i], b[j] を使って買い物するのが明らかに得です。

あとは, ドルでいくつの品物を買うか, および ポンドでいくつの品物を買うかを決定したいですが, ドルで買う品物の数を x とすると, ポンドで買う品物の数は k-x です。
品物の買い方は明らかに値段の少ないものから買うのが最適ですが, この時,

  • ドルで買う品物の値段の合計は, x=0 のときは 0 円で, だんだん値段の増え方が上昇していく凸関数
  • ポンドで買う品物の値段の合計は, x=0 の時最大で, だんだん値段の減り方が減少していく凸関数

であり, 凸関数 + 凸関数 は凸関数なので, 3 分探索をすることで最適な値 x を得ることが出来ます。

整数の 3 分探索初めて書いた気がするので注意点を少し書いておくと,

  • while 終了条件は high-low <= 2
  • 区間の扱い方は, 「最適な値が [low, high] にある」と考えながらやっているので, while が終了した後, low, (low+high)/2, high の 3 つの値のどれが最適かを調べなければならない

ということですかね。あとコンテスト中(virtual participation だけど)にラムダ式を使ったのも初めてでした。引数がいろいろあってめんどいときも [&] と書けば OK なので便利ですね。

const int MAXN = 200020;
const ll INF = 1ll<<55;
int a[MAXN], b[MAXN];

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    int n, m, k, s;
    cin >> n >> m >> k >> s;
    for (int i = 0; i < n; i++) cin >> a[i];
    for (int i = 0; i < n; i++) cin >> b[i];
    vector<pii> item[2];
    for (int i = 0; i < m; i++) {
        int t, c;
        cin >> t >> c;
        item[t-1].emplace_back(c, i+1);
    }
    vector<ll> costSum[2];
    for (int i = 0; i < 2; i++) {
        sort(item[i].begin(), item[i].end());
        int size = item[i].size();
        costSum[i].resize(size+1);
        for (int j = 0; j < size; j++) {
            costSum[i][j+1] = costSum[i][j]+item[i][j].first;
        }
    }
    int ans = 0, besta = 0, bestb = 0;
    for (; ans < n; ans++) {
        if (a[besta] > a[ans]) besta = ans;
        if (b[bestb] > b[ans]) bestb = ans;
        int low = 0, high = min<int>(item[0].size(), k);
        auto f = [&](int x) -> ll {
            if (k > (int)(x+item[1].size())) return INF;
            return costSum[0][x] * a[besta] + costSum[1][k-x] * b[bestb];
        };
        while (high-low > 2) {
            int m1 = (2*low+high)/3, m2 = (low+2*high)/3;
            ll M1 = f(m1), M2 = f(m2);
            if (M1 < M2) high = m2;
            else low = m1;
        }
        int best = low;
        if (f(best) > f(high)) best = high;
        if (f(best) > f((high+low)/2)) best = (high+low)/2;
        if (f(best) <= s) {
            cout << ans+1 << endl;
            for (int i = 0; i < best; i++) {
                cout << item[0][i].second << " " << besta+1 << endl;
            }
            for (int i = 0; i < k-best; i++) {
                cout << item[1][i].second << " " << bestb+1 << endl;
            }
            return 0;
        }
    }
    cout << -1 << endl;
    return 0;
}