mayoko’s diary

プロコンとかいろいろ。

VK Cup 2016 - Round 1 D. Bear and Contribution

解法

まず単純な解法を考えてみます。

目標になる値を決めると, 各値(vals という配列にまとめておくことにします)を目標の値にするために必要なコストが決定します(t だけ contribution を増やしたいとすると, t/5 * c1 + (t%5)*c がかかるコストの最小値(c1 は 5 増やすのに必要なコスト, c は 1 増やすのに必要なコスト))。

なので, n 人それぞれの contribution を調べれば, 各目標値に対して O(N) でコストを調べることができるので, 時間はかかるけど答えは得られそうです。また, 目標値としてあり得る値は, t, t+1, t+2, t+3, t+4 の 5 通りのみを調べれば良いこともわかります(t+5 をやるくらいなら t を目標値にしたほうが得)。これで完全に O(N^2) になりました。

ここで, O(N) 解法にするために, 微妙にしゃくとりっぽいことをやることを考えます(直後に書くことは嘘解法です)。
vals, 目標値をまとめた配列はソートしておきましょう。で, vals に対して [lp, rp] という区間を持っておきます。目標値が増えるごとに, rp を増加させて, vals[rp] <= (目標値) を満たす最大の rp を探します。で, [lp, rp] に含まれる区間が k 個になるように, lp を増加させます。

この考えはなんとなく行けそうな気がしますが, 実際にはダメです。何がダメなのかと言うと,

  • rp を増やした時に増えるコストは計算出来るが, lp を増やして要素を減らすときのコストがどれだけ減少するかが計算できない
  • そもそも [lp, rp] という区間が最適であるとは限らない(5 増やすほうが 4 増やすよりコストが軽いことがあるので)

という点です。

上で書いたコストの計算式から, 5 の倍数ごとにはコストは正確に計算できます。そこで, 目標値を持つベクトルは, 5 で割った余りごとに分類しておきましょう。このようにすると, 1 つ目の問題は解決できます。

また, 2 つ目の問題については, vals の値を 5 で割った余りごとに分類しておくと, 5 で割った余り同士では最も小さいものを取り除いていくのが最適なことは言えるので, 「5 で割った余りが一番小さいものを比べて, 目標値に達するためのコストが最大のものを取り除く」という戦略が出来ます。これは queue を使うと実装しやすいです。

vector<ll> target[5];
ll vals[200200];
const ll INF = 1e9+3;

inline ll calc(ll to, ll from, ll c, ll c1) {
    return ((to-from)%5)*c + (to-from)/5 * c1;
}

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    int n, k, b, c;
    cin >> n >> k >> b >> c;
    ll c1 = min(b, c*5);
    for (int i = 0; i < n; i++) {
        ll t;
        cin >> t;
        t += INF;
        vals[i] = t;
    }
    sort(vals, vals+n);
    for (int i = k-1; i < n; i++) {
        for (int j = 0; j < 5; j++) {
            ll tmp = vals[i]+j;
            target[tmp%5].push_back(tmp);
        }
    }
    for (int i = 0; i < 5; i++) {
        sort(target[i].begin(), target[i].end());
        target[i].erase(unique(target[i].begin(), target[i].end()), target[i].end());
    }
    ll ans = INF*INF;
    for (int i = 0; i < 5; i++) {
        int sz = target[i].size();
        ll sum = 0;
        vector<queue<ll> > que(5);
        // vals の見てる場所, いくつ入っているか
        int vp = 0, cnt = 0;
        for (int j = 0; j < sz; j++) {
            if (j>0) sum += cnt * c1 * (target[i][j]-target[i][j-1]) / 5;
            while (vp < n && vals[vp] <= target[i][j]) {
                cnt++;
                ll t = vals[vp++];
                que[t%5].push(t);
                sum += calc(target[i][j], t, c, c1);
            }
            while (cnt > k) {
                // 1 個減らす
                ll maxi = 0;
                int index = 0;
                for (int l = 0; l < 5; l++) {
                    if (!que[l].empty()) {
                        ll tmp = calc(target[i][j], que[l].front(), c, c1);
                        if (maxi < tmp) {
                            maxi = tmp;
                            index = l;
                        }
                    }
                }
                sum -= maxi;
                que[index].pop();
                cnt--;
            }
            if (cnt >= k) {
                ans = min(ans, sum);
            }
        }
    }
    cout << ans << endl;
    return 0;
}