mayoko’s diary

プロコンとかいろいろ。

SRM648div1med:KitayutaMart

問題:http://community.topcoder.com/stat?c=problem_statement&pm=13649&rd=16312

解法:まず基本として、値段が小さいものから順に買っていくのが最適である。ということで、値段の小さい順に選んでいった時、N番目のものは何かを求めれば良い。
KやNはかなりでかいので、普通に数えていっても間に合わない。そこで、まずアバウトに小さい方から数えていくことを考える。具体的には、以下のようなことを考える。

①区間[1,K]という範囲から、適当な操作をして区間[K+1, 2*K]に値段を変える
②区間[K+1,2*K]という区間に変えると、あとは区間[2*K+1,4*K],[4*K+1,8*K],...というように区間を変えていくのは容易(それぞれの数を1回ずつ2倍すれば良い)。それぞれの区間から次の区間に変えるために掛け算する回数はK回

以上のことを考えながら解法を作る。
まず①にするのにかかる回数を考える。区間[K+1,2*K]にするには、
①まず1〜Kを2倍する
②(元の数が)1〜K/2を2倍する
③(元の数が)1〜K/4を2倍する...
と操作をしていけば良いのでかかる回数は,K+K/2+K/4+...となる。
Nがこの回数(numとする)以下の時は、解は区間[1,2*K]にあることになる。ここで、f(x)を、「値段がx以下になるような商品の総数」とすれば、求める数は、f(x) >= Nかつf(x-1) < Nを満たす数になる。このf(x)は①にかかる回数を求めるのと同様に計算できて、f(x)=min(x,K)+min(x,K/2)+...となる。よって、求めるxの値を2分探索すれば良い。

Nがnumより大きい時は、上の②の事実を使う。M=N-numとすると、2を掛け算する回数はp=(M-1)/K回。また、区間[K+1,2*K]のうちで小さいものから数えてM-p*K番目のものを求めなければならないが、これは上と同様に二分探索で計算することができる。

以下ソースコード

ll getNum(int K) {
    ll ret = 0;
    while (K) {
        ret += K;
        K /= 2;
    }
    return ret;
}

ll get(ll x, ll K) {
    ll ret = 0;
    while (x) {
        ret += min(K, x);
        x /= 2;
    }
    return ret;
}

const ll MOD = 1e9+7;

// x^p
ll pow_mod(ll x, ll p, ll m = MOD) {
    if (x == 0) return 0;
    if (p == 0) return 1;
    if (p == 1) return x;
    if (p % 2 == 0) {
        ll tmp = pow_mod(x, p/2, m);
        return (tmp * tmp) % m;
    } else {
        return (x*pow_mod(x, p-1, m)) % m;
    }
}

class KitayutaMart {
public:
    int lastPrice(int N, int K) {
        ll num = getNum(K);
        cout << num << endl;
        if (num >= (ll)N) {
            ll low = 0, high = 2*K;
            while (high - low > 1) {
                ll med = (high + low) / 2;
                if (get(med, K) >= N) high = med;
                else low = med;
            }
            return (int)high;
        }
        // N is big
        int M = N-num;
        ll p = (M-1) / K; // 2をかける回数
        ll q = M - p*K;
        ll low = K, high = 2*K;
        while (high - low > 1) {
            ll med = (high + low) / 2;
            if (get(med, K)-get(K, K) >= q) high = med;
            else low = med;
        }
        printf("high is %lld, p is %lld\n", high, p);
        return (pow_mod(2,p)*high) % MOD;
    }
};