mayoko’s diary

プロコンとかいろいろ。

SRM 674 div1 med:FindingKids

解法

合ってるとは言ってないですが最後の結構硬そうなサンプル通ってるので多分合ってると思います(珍しく #include とか省略しないで書いてるので誰か確かめてほC)。
提出できました。

まず蟻本を思い出すと, ぶつかって反射するのはどっちもそのまま直進してすれ違うのと同じです。
また, 左から k 番目の車に乗ってる子は, 常に左から k 番目になっている, ということに注目します。t 番目のクエリ kid[t] に対し, kid[t] が最初に左から k 番目にいることが分かっていれば, 各クエリでは, 左から k 番目の子の位置をソートして求めれば良いです。ただ, これを単純に実装すると O(nq log n) なのでまだ間に合いません。

そこで, ある位置 x を決めた時, 時間 time[t] の時点で座標が x 以下である人の数が k になるような, x の最小値を二分探索で求めます。これを求めるためには, 最初に左を向いてる人と最初に右を向いてる人を分けて, 座標について lower_bound すれば OK です。

const ll MOD = 1e9+7;

class FindingKids {
public:
    long long getSum(int n, int q, int A, int B, int C) {
        // about dir  0: right 1: left
        vector<int> dir(n), kid(q);
        vector<ll> pos(n), time(q);
        vector<pair<ll, int> > P;
        {
            ll a = A, b = B, c = C;
            set<ll> S;
            for (int i = 0; i < n; i++) {
                a = (a*b+c)%MOD;
                ll p = a % (MOD-n+i+1);
                if (S.count(p)) p = MOD-n+i;
                pos[i] = p;
                S.insert(p);
                P.emplace_back(p, i);
                if (p%2 == 0) dir[i] = 0;
                else dir[i] = 1;
            }
            for (int i = 0; i < q; i++) {
                a = (a*b+c)%MOD;
                kid[i] = (int)(a%n);
                a = (a*b+c)%MOD;
                time[i] = a;
            }
        }
        ll ans = 0;
        //lpos: 左に向かっていく
        //rpos: 右に向かっていく
        vector<ll> lpos, rpos;
        for (int i = 0; i < n; i++) {
            if (dir[i]) lpos.push_back(pos[i]);
            else rpos.push_back(pos[i]);
        }
        sort(lpos.begin(), lpos.end());
        sort(rpos.begin(), rpos.end());
        sort(P.begin(), P.end());
        vector<int> order(n);
        for (int i = 0; i < n; i++) {
            order[P[i].second] = i;
        }
        for (int t = 0; t < q; t++) {
            int o = order[kid[t]];
            ll low = -1ll<<32, high = 1ll<<32;
            while (high - low > 1) {
                ll med = (high+low)/2;
                int num = 0;
                num += lower_bound(lpos.begin(), lpos.end(), med+time[t]+1) - lpos.begin();
                num += lower_bound(rpos.begin(), rpos.end(), med-time[t]+1) - rpos.begin();
                if (num >= o+1) high = med;
                else low = med;
            }
            ans += abs(high);
        }
        return ans;
    }
};