mayoko’s diary

プロコンとかいろいろ。

JAG 夏合宿 Day2 A - Parades

問題

jag2016summer-day2.contest.atcoder.jp

N 頂点からなる木がある。各頂点の次数はたかだか 10 である。

このグラフ上でいくつかのパレードを開きたい。パレードの候補は M 個あり, その各パレードは頂点 u から 頂点 v へのパスで構成される。
これらのパレードは同時開催するため, 開くパレードはすべてパスの辺を共有していてはならない。最高でいくつのパレードを開けるかを求めよ。

解法

藤原さんの解答を参考にしました。

基本的には木 DP で, dp[v][s] = (頂点 v から伸びている辺のうち, s で表される集合の辺は使ったときに開催できるパレードの数の最大値) とします。

パレードはパスで表されるので, u, v の lca で特徴づけることができます。u -> lca -> v と移動する場合, lca では u 側に降りるために使う辺と v 側に降りるために使う辺を消費します。この辺が lca にとって x 番目, y 番目の辺であった場合,
dp[lca][s|(1< lca -> v みたいなパスが候補になったとして, そのパス自体でパレードの数は +1 されますが, そのパス上で行われていたパレードの数がどうなるかを考慮する必要があります。これをいちいち計算していると O(nm) 以上の計算量がかかりそうなのでアウトです。

ということで工夫をする必要があるわけですが, この木 dp は明らかに葉のほうから先にやっていきます。「末端が u であるようなパス」というのは lca -> u というように書けるわけですが, この lca はこれから先どんどん上に伸びていくだけなので, 「今調べている段階で u を終着点とするようなパスでどれだけのパレードを開けているか」というようなものを考えれば良いことになります。下のコードでこれをやっているのは S[v] というやつです。今までのパスの上に一つ頂点がつくことによってパレードの回数が更新される, という感じです。

あと問題なのは 各頂点に対して, その頂点を lca とするものは最高で O(n) 個考えられ, bitDP やってるときにいちいちやっていると O(n^2 2^10) かかって死ぬ, という問題があります。しかしあらかじめ M[x][y] = (頂点 lca で x 本目, y 本目の辺を使う場合に最大で開けるパレードの数) を前計算しておくとこれは大丈夫になります。

int main() {
	cin.tie(0);
	ios::sync_with_stdio(false);
	int T;
	cin >> T;
	while (T--) {
		int N;
		cin >> N;
		vector<vi> G(N);
		for (int i = 0; i < N-1; i++) {
			int a, b;
			cin >> a >> b;
			a--; b--;
			G[a].push_back(b);
			G[b].push_back(a);
		}
		// 木についていろいろメモ
		// 頂点を調べる順番
		vector<int> T(N);
		// 子
		vector<vi> chs(N);
		// 親
		vector<int> par(N);
		// 深さ
		vector<int> d(N);
		{
			d[0] = 1;
			par[0] = -1;
			int k = 0;
			queue<int> que;
			que.push(0);
			while (!que.empty()) {
				int now = que.front(); que.pop();
				T[k++] = now;
				for (int ch : G[now]) {
					if (!d[ch]) {
						par[ch] = now;
						chs[now].push_back(ch);
						d[ch] = d[now] + 1;
						que.push(ch);
					}
				}
			}
		}
		// anc[v][u] = v, u の LCA
		vector<vi> anc(N, vi(N));
		for (int i = 0; i < N; i++) {
			int v = T[i];
			for (int j = 0; j < N; j++) {
				int u = T[j];
				if (v == u) {
					anc[v][u] = v;
				}
				else if (d[v] > d[u]) {
					anc[v][u] = anc[par[v]][u];
				}
				else if (d[v] < d[u]) {
					anc[v][u] = anc[v][par[u]];
				}
				else if (d[v] > 1) {
					anc[v][u] = anc[par[v]][u];
				}
				else {
					anc[v][u] = 0;
				}
			}
		}
		vector<vi> dir(N, vi(N));
		for (int i = 0; i < N; i++) for (int j = 0; j < N; j++) {
			int v = T[i], u = T[j];
			if (v != u && anc[v][u] == v) {
				for (int x = 0; x < chs[v].size(); x++) {
					int w = chs[v][x];
					if (anc[w][u] == w) {
						dir[v][u] = x;
						break;
					}
				}
			}
		}
		vector<vector<pii> > V(N);
		int M;
		cin >> M;
		for (int i = 0; i < M; i++) {
			int a, b;
			cin >> a >> b;
			a--; b--;
			int lca = anc[a][b];
			V[lca].emplace_back(a, b);
		}
		// dp
		vector<vi> dp(N, vi(1<<10));
		vi S(N);
		for (int t = N-1; t >= 0; t--) {
			int v = T[t];
			vi M1(10);
			vector<vi> M(10, vi(10));
			for (pii des : V[v]) {
				int a = des.first, b = des.second;
				if (a == v) {
					int x = dir[v][b];
					M1[x] = max(M1[x], S[b]+1);
				}
				else if (b == v) {
					int x = dir[v][a];
					M1[x] = max(M1[x], S[a]+1);
				}
				else {
					int x = dir[v][a], y = dir[v][b];
					M[x][y] = max(M[x][y], S[a] + S[b] + 1);
				}
			}
			int G = 0;
			for (int ch : chs[v]) {
				G += dp[ch][(1<<chs[ch].size())-1];
			}
			dp[v][0] = G;
			int sz = chs[v].size();
			for (int s = 0; s < 1<<sz; s++) {
				for (int x = 0; x < sz; x++) {
					if ((s>>x)&1) continue;
					dp[v][s|(1<<x)] = max(dp[v][s|(1<<x)], dp[v][s]);
					int ch = chs[v][x];
					dp[v][s|(1<<x)] = max(dp[v][s|(1<<x)], dp[v][s] + M1[x] - dp[ch][(1<<chs[ch].size())-1]);
					for (int y = 0; y < sz; y++) if (x != y && ((s>>y)&1) == 0) {
						int ch2 = chs[v][y];
						dp[v][s|(1<<x)|(1<<y)] = max(dp[v][s|(1<<x)|(1<<y)], dp[v][s] + M[x][y] - dp[ch][(1<<chs[ch].size())-1] - dp[ch2][(1<<chs[ch2].size())-1]);
					}
				}
			}
			int All = (1<<sz)-1;
			for (int des = 0; des < N; des++) {
				if (des != v && anc[v][des] == v) {
					int x = dir[v][des];
					int ch = chs[v][x];
					S[des] += dp[v][All^(1<<x)] - dp[ch][(1<<chs[ch].size())-1];
				}
			}
			S[v] = dp[v][(1<<sz)-1];
		}
		cout << dp[0][(1<<chs[0].size())-1] << endl;
	}
	return 0;
}