AtCoder Regular Contest 051 D - 長方形
解法
rng さんの解説放送を聞いて解きました。やっぱり解説放送聞くとスッと理解しやすいですね。
rng さんが書いてたコード↓
arc051.contest.atcoder.jp
まずひとつのクエリを O(WH) で処理することを考えます。
長方形の形を w*h に固定します。a[i]+a[i+1]+...+a[i+w-1] の和が asum, b[j]+b[j+1]+...+b[j+h-1] の和が bsum であるとすると, マス目の総和は, asum*h+bsum*w で求めることが出来ます。なので,
- amax[i][j] = (インデックス i までで, j 個の連続する区間を選んだ時の a[k]+...+a[k+j-1] の最大値), bmax[i][j] を前計算しておく(これは O(H^2+W^2) で出来る)
- 各クエリ(X, Y)に対して, (amax[X][i]*j+bmax[Y][j]*i) をすべて試す(これは O(WH))
これで O(QWH) の解法が出来ました。AC する解法ではふたつ目の計算を高速化します。
とりあえず j を固定して考えることにします。すると, j と bmax[Y][j] は定数です。
i <-> amax[X][i] の対応を, (i, amax[X][i]) という二次元座標で表すことにします。i が x 座標, amax[X][i] が y 座標です。すると, amax[X][i]*j + bmax[Y][j]*i は (定数)*y + (定数)*x という式であることがわかります。
よって問題は 点群 (i, amax[X][i]) の中から, 上記の(定数)*y + (定数)*x の値を最大にするものを求める, というものになりました。
大学受験でよくあるやつですが, これを最大にするのは, 直線 y = -bmax[Y][j]/j*x + k の k を INF からだんだん下ろしていった時, 最初に点群とぶつかる点が答えです。
上から直線を下ろしていくので, 点群 (i, amax[X][i]) から凸包を作っておきます。で, 凸包の点の中でどの点に最初に当たるかは, 三分探索することが出来ます("凸"包なので)。ただ, 実際には衝突する点を調べるんじゃなくて, (定数)*y + (定数)*x を最大化するような三分探索をすれば OK です。
予め凸包を作っておけば各 j に対して O(log X) で最大値を得られるので, 時間内に答えを求められます。
const int MAX = 2222; const ll INF = 1ll<<55; ll a[MAX], b[MAX], asum[MAX], bsum[MAX]; ll amax[MAX][MAX], bmax[MAX][MAX]; // amax[最大index][長さ] = 最大値 struct P { ll x, y; P() {} P(ll x, ll y) : x(x), y(y) {} P operator+(P p) const {return P(x+p.x, y+p.y);} P operator-(P p) const {return P(x-p.x, y-p.y);} ll dot(P p) const {return x*p.x + y*p.y;} // 内積 ll det(P p) const {return x*p.y - y*p.x;} // 外積 bool operator<(const P& rhs) const { if (x != rhs.x) return x < rhs.x; return y < rhs.y; } }; // 凸包 vector<P> convex_hull(vector<P> ps) { int n = ps.size(); sort(ps.rbegin(), ps.rend()); int k = 0; // 凸包の頂点数 vector<P> qs(n*2); // 構築中の凸包 // 上側凸包の作成 for (int i = 0; i < n; i++) { while (k > 1 && (qs[k-1]-qs[k-2]).det(ps[i]-qs[k-1]) <= 0) k--; qs[k++] = ps[i]; } qs.resize(k); return qs; } ll query(int X, int Y) { ll ret = -INF; vector<P> ps; for (int i = 1; i <= X; i++) { ps.emplace_back(i, amax[X][i]); } auto qs = convex_hull(ps); int K = qs.size(); for (int j = 1; j <= Y; j++) { int low = 0, high = K-1; while (high-low > 2) { int d = (high-low)/3; int l = low+d, r = high-d; if (qs[l].y*j+bmax[Y][j]*qs[l].x > qs[r].y*j+bmax[Y][j]*qs[r].x) high = r; else low = l; } for (int i = low; i <= high; i++) ret = max(ret, qs[i].y*j + qs[i].x*bmax[Y][j]); } // for (int i = 1; i <= X; i++) for (int j = 1; j <= Y; j++) { // ret = max(ret, amax[X][i]*j+bmax[Y][j]*i); // } return ret; } int main() { cin.tie(0); ios::sync_with_stdio(false); for (int i = 0; i < MAX; i++) for (int j = 0; j < MAX; j++) { amax[i][j] = -INF; bmax[i][j] = -INF; } int W, H; cin >> W >> H; for (int i = 0; i < W; i++) cin >> a[i]; for (int i = 0; i < H; i++) cin >> b[i]; for (int i = 0; i < W; i++) asum[i+1] = asum[i]+a[i]; for (int i = 1; i <= W; i++) for (int j = 1; j <= W; j++) { amax[i][j] = max(amax[i][j], amax[i-1][j]); if (i-j >= 0) amax[i][j] = max(amax[i][j], asum[i]-asum[i-j]); } for (int i = 0; i < H; i++) bsum[i+1] = bsum[i]+b[i]; for (int i = 1; i <= H; i++) for (int j = 1; j <= H; j++) { bmax[i][j] = max(bmax[i][j], bmax[i-1][j]); if (i-j >= 0) bmax[i][j] = max(bmax[i][j], bsum[i]-bsum[i-j]); } int Q; cin >> Q; while (Q--) { int X, Y; cin >> X >> Y; cout << query(X, Y) << endl; } return 0; }