mayoko’s diary

プロコンとかいろいろ。

yukicoder No.31 悪のミックスジュース

解法

まず, N >= V の場合はそれぞれのジュースを 1 リットルずつ買って適当に分配すれば良いです。
V > N の場合もとりあえずすべてのジュースを 1 リットルずつ買っておけば「使わない果物があってはいけない」という制約を考える必要がなくなるので, 買ったことにして, V から N を引いておきます。

問題を線形計画問題っぽく定式化すると,
min
 C_1p_1 + C_2p_2 + ... + C_Np_N
s.t.
 p_1 \geq p_2 \geq p_3 ... \geq p_N
 p_1 + p_2 + ... + p_N \geq V
 p_i-p_{i+1} \geq 0 (1 \leq i \leq N-1)
 p_i \geq 0 (1 \leq i \leq N)
となります。

ですが, このままだとよくわからないので更に変形していきます。 x_i = p_i - p_{i+1} とおくと,
 C_1p_1 + C_2p_2 + ... + C_Np_N
 = C_1(p_1-p_2) + (C_1 + C_2)p_2 + ... + C_Np_N
 = C_1(p_1-p_2) + (C_1 + C_2)(p_2-p_1) + (C_1+C_2+C_3)p_3
 = ...
 = C_1 x_1 + (C_1+C_2)x_2 + (C_1+C_2+C_3)x_3 + ... + (C_1+C_2+...+C_N)x_N
となります。同様に,
 p_1+p_2+...+p_N \geq V
という制約式も,
 x_1+2x_2+...+Nx_N \geq V
という制約式になります。

また,  x_i の定義から,  x_i \geq 0 が制約になります。よって, 問題は
min
 C_1 x_1 + (C_1+C_2)x_2 + (C_1+C_2+C_3)x_3 + ... + (C_1+C_2+...+C_N)x_N
s.t.
 x_1+2x_2+...+Nx_N \geq V
 x_i \geq 0 (1 \leq i \leq N)
となります。
これならなんかいけそうです。というか, この問題は yukicoder No.288 と同じように解くことが出来ます。mayokoex.hatenablog.com

No.288 と同じように, V が大きいことがネックですが, 鳩の巣原理を使うと,  (C_1+C_2+...+C_i)/i の値がなるべく大きい物を選ぶだけで  N^2 のところまでは決めることが出来ます。あとは動的計画法でがんばりましょう。

const int MAXN = 111;
const ll INF = 1ll<<60;
ll C[MAXN], S[MAXN];
pair<ll, int> P[MAXN];

ll dp[MAXN][MAXN*MAXN];

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    int N;
    ll V;
    cin >> N >> V;
    for (int i = 0; i < N; i++) {
        cin >> C[i];
    }
    S[0] = C[0];
    P[0] = make_pair(S[0], 1);
    for (int i = 1; i < N; i++) {
        S[i] = S[i-1] + C[i];
        P[i] = make_pair(S[i], i+1);
    }
    ll ans = S[N-1];
    V -= N;
    if (V <= 0) {
        cout << ans << endl;
        return 0;
    }
    sort(P, P+N, [](const pair<ll, int>& lhs, const pair<ll, int>& rhs) -> bool {return lhs.first*rhs.second < rhs.first*lhs.second;});
    // V の残りが N^2 になるまでは引いていっても良い
    ll ok = V-N*N;
    if (ok > 0) {
        ll num = ok/P[0].second;
        if (ok%P[0].second) num++;
        ans += P[0].first*num;
        V -= P[0].second*num;
    }
    assert(V <= N*N);
    // 後は dp
    for (int i = 0; i < N; i++) for (int j = 0; j <= N*N; j++) dp[i][j] = INF;
    dp[0][0] = 0;
    for (int i = 0; i < N; i++) {
        if (i > 0) {
            for (int j = 0; j <= N*N; j++) {
                dp[i][j] = min(dp[i][j], dp[i-1][j]);
            }
        }
        for (int j = 0; j <= N*N; j++) {
            if (j-(i+1) >= 0) dp[i][j] = min(dp[i][j], dp[i][j-i-1]+S[i]);
        }
    }
    cout << ans + dp[N-1][V] << endl;
    return 0;
}