mayoko’s diary

プロコンとかいろいろ。

SRM 546 div1 med: FavouriteDigits

人の不幸で飯がうまい

問題

TopCoder Statistics - Problem Statement

数 n が与えられる。n 以上の数で, 以下の条件を満たす最小の整数を求めよ。

  • 10 進法で表したとき, その数には count1 個以上の digit1 が含まれる。
  • 10 進法で表したとき, その数には count2 個以上の digit2 が含まれる。
解法

桁DP + 二分探索で解けます。

calc(x) = (x 未満で digit1 を count1 以上, digit2 を count2 以上持つ整数の数) とすると, 求める整数は calc(x) - calc(n) = 1 を満たす整数です。なので calc(x) - calc(n) について二分探索をしましょう。

calc(x) は普通の桁 DP で求めることができます。

ll n;
int d1, d2, cnt1, cnt2;

// keta, cnt1, cnt2, small, leadingzero
ll dp[20][20][20][2][2];

// x 未満で d1 を cnt1 以上, d2 を cnt2 以上持つのを探す
ll calc(ll x) {
    string s = to_string(x);
    int n = s.size();
    memset(dp, 0, sizeof(dp));
    dp[0][0][0][0][0] = 1;
    for (int keta = 0; keta < n; keta++) for (int c1 = 0; c1 <= cnt1; c1++) for (int c2 = 0; c2 <= cnt2; c2++) {
        for (int small = 0; small < 2; small++) for (int lz = 0; lz < 2; lz++) {
            if (!dp[keta][c1][c2][small][lz]) continue;
            int num = s[keta] - '0';
            for (int i = 0; i < 10; i++) {
                if (!small && i > num) continue;
                int nc1 = c1, nc2 = c2;
                int nsmall = small, nlz = lz;
                if (i == d1) {
                    if (lz || d1 != 0) nc1++;
                }
                if (i == d2) nc2++;
                nc1 = min(nc1, cnt1);
                nc2 = min(nc2, cnt2);
                nsmall |= i < num;
                nlz |= (i != 0);
                dp[keta+1][nc1][nc2][nsmall][nlz] += dp[keta][c1][c2][small][lz];
            }
        }
    }
    return dp[n][cnt1][cnt2][1][1];
}

bool ok(ll x) {
    return calc(x) - calc(n) > 0;
}

class FavouriteDigits {
public:
    long long findNext(long long N, int digit1, int count1, int digit2, int count2) {
        n = N, d1 = digit1, d2 = digit2, cnt1 = count1, cnt2 = count2;
        if (d1 > d2) {
            swap(d1, d2);
            swap(cnt1, cnt2);
        }
        ll low = 0, high = 1ll<<60;
        while (high - low > 1) {
            const ll med = (low+high)/2;
            if (ok(med)) high = med;
            else low = med;
        }
        return low;
    }
};