ARC086 E - Smuggling Marbles
解法
木DPを行う。dp[i][j][k]=(j回動かしたときに頂点iにk個(=0,1か2個以上)ビー玉が存在するような初期配置の数) と定義してDPをする。頂点iの子をc_0とc_1とする。これらの子の情報dp[c_0]とdp[c_1]をマージする方法について考える。j=0~d(頂点iの高さ)について以下の計算を行う。これらが1回以上移動したときの組み合わせ数の計算結果となる。1回も動かさないときの組み合わせ数のdp[i][j][0]=dp[i][j][1]=1, dp[i][j][2]=0を先頭に挿入することで頂点iについての結果を得ることができる。
- dp[i][j][0] = dp[c_0][j][0] * dp[c_1][j][0]
- dp[i][j][1] = dp[c_0][j][1] * dp[c_1][j][0] + dp[c_0][j][0] * dp[c_0][j][1]
- dp[i][j][2] = dp[c_0][j][2] * (dp[c_1][j][0] + dp[c_1][j][1] + dp[c_1][j][0]) + dp[c_0][j][1] * (dp[c_1][j][1] + dp[c_1][j][2]) + dp[c_0][j][0] * dp[c_1][j][2]
移動した後に頂点に2個以上存在している場合はそのビー玉を取り除く。したがって上記の処理を行ったあとに dp[i][j][0] += dp[i][j][2], dp[i][j][2] = 0 とする。
このマージを普通に書くと遷移にO(N)かかり全体でO(N2)かかってTLEしてしまう。この高速化にデータ構造をマージする一般的テクを用いる。子c_0とc_1をマージするときにサイズが小さい方から大きい方へ向かってマージするように書くとマージテクの要領でO(NlogN)のように考えられる。公式解説にあるようにLCAを考えるとO(N)となっていることがわかる。実装上の注意点として2個以上存在しているビー玉を取り除く場所でループする範囲を必要最小限となるように取らないと、その部分で計算量が増えてTLEしてしまうので要注意。
2乗の木DPみたいな雰囲気を感じる
CSAにuniform treeという類題があります
#include <bits/stdc++.h> using namespace std; using ll = long long; // #define int ll using PII = pair<ll, ll>; #define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i) #define REP(i, n) FOR(i, 0, n) #define ALL(x) x.begin(), x.end() template<typename T> T &chmin(T &a, const T &b) { return a = min(a, b); } template<typename T> T &chmax(T &a, const T &b) { return a = max(a, b); } template<typename T> bool IN(T a, T b, T x) { return a<=x&&x<b; } template<typename T> T ceil(T a, T b) { return a/b + !!(a%b); } template<typename T> vector<T> make_v(size_t a) { return vector<T>(a); } template<typename T,typename... Ts> auto make_v(size_t a,Ts... ts) { return vector<decltype(make_v<T>(ts...))>(a,make_v<T>(ts...)); } template<typename T,typename V> typename enable_if<is_class<T>::value==0>::type fill_v(T &t, const V &v) { t=v; } template<typename T,typename V> typename enable_if<is_class<T>::value!=0>::type fill_v(T &t, const V &v ) { for(auto &e:t) fill_v(e,v); } template<class S,class T> ostream &operator <<(ostream& out,const pair<S,T>& a){ out<<'('<<a.first<<','<<a.second<<')'; return out; } template<typename T> istream& operator >> (istream& is, vector<T>& vec){ for(T& x: vec) {is >> x;} return is; } template<class T> ostream &operator <<(ostream& out,const vector<T>& a){ out<<'['; for(T i: a) {out<<i<<',';} out<<']'; return out; } int dx[] = {0, 1, 0, -1}, dy[] = {1, 0, -1, 0}; // DRUL const int INF = 1<<30; const ll LLINF = 1LL<<60; const ll MOD = 1000000007; signed main(void) { cin.tie(0); ios::sync_with_stdio(false); ll n; cin >> n; ++n; vector<vector<ll>> g(n); REP(i, n-1) { ll p; cin >> p; g[p].push_back(i+1); } vector<ll> depth(n); function<void(ll,ll)> dep = [&](ll v, ll d) { depth[d]++; for(auto to: g[v]) { dep(to, d+1); } }; struct node { // 頂点に0,1,2個ビー玉が存在するときについて ll x, y, z; node() {} node(ll x, ll y, ll z) : x(x), y(y), z(z) {} // dpの情報のマージ node plus(node n) { ll nx = x*n.x % MOD; ll ny = (y*n.x % MOD + x*n.y % MOD) % MOD; ll nz = (z*(n.x+n.y+n.z) % MOD + y*(n.y+n.z) % MOD + x*n.z % MOD) % MOD; return node(nx, ny, nz); } }; function<deque<node>(ll)> dfs = [&](ll v) { // ret[i] = (頂点vでi回移動したときの情報) deque<node> ret; ll sz = 0; for(auto to: g[v]) { auto dq = dfs(to); if(ret.size() < dq.size()) swap(ret, dq); // dq -> retへマージ REP(i, dq.size()) ret[i] = ret[i].plus(dq[i]); chmax(sz, (ll)dq.size()); } // ret.size()まで回すのではなくmax(dq.size())までに抑える REP(i, sz) { (ret[i].x += ret[i].z) %= MOD; ret[i].z = 0; } ret.push_front(node(1, 1, 0)); return ret; }; dep(0, 0); auto ret = dfs(0); vector<ll> pow2(n+1); pow2[0] = 1; FOR(i, 1, n+1) pow2[i] = pow2[i-1] * 2 % MOD; ll d = 0, ans = 0; for(auto i: ret) (ans += i.y * pow2[n - depth[d++]] % MOD) %= MOD; cout << ans << endl; return 0; }