mayoko’s diary

プロコンとかいろいろ。

yukicoder No.315 世界のなんとか3.5

解法

まずよくある桁 DP っぽく考えると,
dp[n][big][exist3][mod3][modP] = (n 桁目で元の数が今作ってる数より大きくなっているフラグが big で 3 が数字に含まれているフラグが exist3 で 3 で割った余りが mod3 で P で割った余りが modP であるような場合の数)
とすれば良さそうです。が, P の値がまぁまぁ大きいし桁の数も大きいので, これをそのままやると計算量が 10^9 くらいいってしまい死にます。

ところが P は8, 80, 800 に限定されています。例えば P = 8 のときを考えると, これは下 3 桁を見れば 8 の倍数になっているかがわかる, というのは有名な話です。
同様に, P=80, 800 のときも, 下 4 桁, 下 5 桁を見れば大丈夫です。

よって, n 桁の数について題意の数を調べたいとすると, n-5 桁まで dp[n][big][exist3][mod3] を調べて, それとは別に 5 桁分 dp[n][big][exist3][mod3][modP] を調べて, 結合してあげましょう。
イデアはそんな感じですが実装は結構重いです。頑張って下さい。

const ll MOD = 1e9+7;
const int MAXN = 200020;
ll dp[MAXN][2][2][3]; // n, big, exist3, mod3
ll memo1[10][2][2][3][800]; // n, big, exist3, mod3, modP
ll memo2[10][2][3][800]; // n, exist3, mod3, modP
ll memo[2][2][3]; // すでにbigと決定, exist3, mod3

int check(string A, int P) {
    int sum = 0, exist3 = 0, modP = 0;
    for (char c : A) {
        int num = c-'0';
        sum += num;
        exist3 |= (num==3);
        modP = (modP*10+num)%P;
    }
    if (modP) {
        if (sum%3 == 0 || exist3) return 1;
    }
    return 0;
}

ll calc(string s, int P) {
    if (s.size() <= 5) {
        int A = stoi(s);
        ll ans = 0;
        for (int i = 0; i <= A; i++) {
            if (check(to_string(i), P)) ans++;
        }
        return ans;
    }
    memset(memo1, 0, sizeof(memo1));
    memset(memo2, 0, sizeof(memo2));
    memset(memo, 0, sizeof(memo));
    memset(dp, 0, sizeof(dp));
    int n = s.size();
    string t = s.substr(n-5);
    // memo1 を求める
    memo1[0][0][0][0][0] = 1;
    for (int i = 0; i < 5; i++) for (int big = 0; big < 2; big++) {
        for (int e3 = 0; e3 < 2; e3++) for (int m3 = 0; m3 < 3; m3++) for (int mp = 0; mp < P; mp++) {
            if (big) {
                for (int j = 0; j < 10; j++) {
                    int nm3 = (m3+j)%3, nmp = (mp*10+j)%P;
                    int ne3 = e3;
                    ne3 |= (j==3);
                    memo1[i+1][big][ne3][nm3][nmp] += memo1[i][big][e3][m3][mp];
                    memo1[i+1][big][ne3][nm3][nmp] %= MOD;
                }
            } else {
                int num = t[i]-'0';
                for (int j = 0; j <= num; j++) {
                    int nm3 = (m3+j)%3, nmp = (mp*10+j)%P;
                    int ne3 = e3;
                    ne3 |= (j==3);
                    int nbig = (j<num);
                    memo1[i+1][nbig][ne3][nm3][nmp] += memo1[i][big][e3][m3][mp];
                    memo1[i+1][nbig][ne3][nm3][nmp] %= MOD;
                }
            }
        }
    }
    for (int e3 = 0; e3 < 2; e3++) for (int m3 = 0; m3 < 3; m3++) {
        for (int mod = 1; mod < P; mod++) for (int big = 0; big < 2; big++) (memo[0][e3][m3] += memo1[5][big][e3][m3][mod]) %= MOD;
    }
    // memo2 を求める
    memo2[0][0][0][0] = 1;
    for (int i = 0; i < 5; i++) for (int e3 = 0; e3 < 2; e3++) {
        for (int m3 = 0; m3 < 3; m3++) for (int mp = 0; mp < P; mp++) {
            for (int j = 0; j < 10; j++) {
                int nm3 = (m3+j)%3, nmp = (mp*10+j)%P;
                int ne3 = e3;
                ne3 |= (j==3);
                memo2[i+1][ne3][nm3][nmp] += memo2[i][e3][m3][mp];
                memo2[i+1][ne3][nm3][nmp] %= MOD;
            }
        }
    }
    for (int e3 = 0; e3 < 2; e3++) for (int m3 = 0; m3 < 3; m3++) {
        for (int mod = 1; mod < P; mod++) (memo[1][e3][m3] += memo2[5][e3][m3][mod]) %= MOD;
    }
    // dp を求める
    dp[0][0][0][0] = 1;
    for (int i = 0; i < n-5; i++) for (int big = 0; big < 2; big++) {
        for (int e3 = 0; e3 < 2; e3++) for (int m3 = 0; m3 < 3; m3++) {
            if (big) {
                for (int j = 0; j < 10; j++) {
                    int nm3 = (m3+j)%3;
                    int ne3 = e3;
                    ne3 |= (j==3);
                    dp[i+1][big][ne3][nm3] += dp[i][big][e3][m3];
                    dp[i+1][big][ne3][nm3] %= MOD;
                }
            } else {
                int num = s[i]-'0';
                for (int j = 0; j <= num; j++) {
                    int nm3 = (m3+j)%3;
                    int ne3 = e3;
                    ne3 |= (j==3);
                    int nbig = (j<num);
                    dp[i+1][nbig][ne3][nm3] += dp[i][big][e3][m3];
                    dp[i+1][nbig][ne3][nm3] %= MOD;
                }
            }
        }
    }
    ll ans = 0;
    for (int big = 0; big < 2; big++) {
        for (int e3 = 0; e3 < 2; e3++) {
            for (int m3 = 0; m3 < 3; m3++) {
                if (!e3) {
                    ans += dp[n-5][big][e3][m3] * memo[big][0][(3-m3)%3] % MOD;
                    for (int j = 0; j < 3; j++) {
                        ans += dp[n-5][big][e3][m3] * memo[big][1][j] % MOD;
                    }
                } else {
                    for (int i = 0; i < 2; i++) for (int j = 0; j < 3; j++) {
                        ans += dp[n-5][big][e3][m3] * memo[big][i][j] % MOD;
                    }
                }
                ans %= MOD;
            }
        }
    }
    return ans;
}

int main() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    string A, B;
    int P;
    cin >> A >> B >> P;
    ll ans = calc(B, P)-calc(A, P)+check(A, P);
    ans = (ans%MOD + MOD) % MOD;
    cout << ans << endl;
    return 0;
}