ferinの競プロ帳

競プロについてのメモ

ARC086 E - Smuggling Marbles

問題ページ

解法

木DPを行う。dp[i][j][k]=(j回動かしたときに頂点iにk個(=0,1か2個以上)ビー玉が存在するような初期配置の数) と定義してDPをする。頂点iの子をc_0とc_1とする。これらの子の情報dp[c_0]とdp[c_1]をマージする方法について考える。j=0~d(頂点iの高さ)について以下の計算を行う。これらが1回以上移動したときの組み合わせ数の計算結果となる。1回も動かさないときの組み合わせ数のdp[i][j][0]=dp[i][j][1]=1, dp[i][j][2]=0を先頭に挿入することで頂点iについての結果を得ることができる。

  • dp[i][j][0] = dp[c_0][j][0] * dp[c_1][j][0]
  • dp[i][j][1] = dp[c_0][j][1] * dp[c_1][j][0] + dp[c_0][j][0] * dp[c_0][j][1]
  • dp[i][j][2] = dp[c_0][j][2] * (dp[c_1][j][0] + dp[c_1][j][1] + dp[c_1][j][0]) + dp[c_0][j][1] * (dp[c_1][j][1] + dp[c_1][j][2]) + dp[c_0][j][0] * dp[c_1][j][2]

移動した後に頂点に2個以上存在している場合はそのビー玉を取り除く。したがって上記の処理を行ったあとに dp[i][j][0] += dp[i][j][2], dp[i][j][2] = 0 とする。
このマージを普通に書くと遷移にO(N)かかり全体でO(N2)かかってTLEしてしまう。この高速化にデータ構造をマージする一般的テクを用いる。子c_0とc_1をマージするときにサイズが小さい方から大きい方へ向かってマージするように書くとマージテクの要領でO(NlogN)のように考えられる。公式解説にあるようにLCAを考えるとO(N)となっていることがわかる。実装上の注意点として2個以上存在しているビー玉を取り除く場所でループする範囲を必要最小限となるように取らないと、その部分で計算量が増えてTLEしてしまうので要注意。

2乗の木DPみたいな雰囲気を感じる
CSAにuniform treeという類題があります

#include <bits/stdc++.h>
 
using namespace std;
using ll = long long;
// #define int ll
using PII = pair<ll, ll>;
 
#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()
 
template<typename T> T &chmin(T &a, const T &b) { return a = min(a, b); }
template<typename T> T &chmax(T &a, const T &b) { return a = max(a, b); }
template<typename T> bool IN(T a, T b, T x) { return a<=x&&x<b; }
template<typename T> T ceil(T a, T b) { return a/b + !!(a%b); }
 
template<typename T> vector<T> make_v(size_t a) { return vector<T>(a); }
template<typename T,typename... Ts>
auto make_v(size_t a,Ts... ts) { 
  return vector<decltype(make_v<T>(ts...))>(a,make_v<T>(ts...));
}
template<typename T,typename V> typename enable_if<is_class<T>::value==0>::type
fill_v(T &t, const V &v) { t=v; }
template<typename T,typename V> typename enable_if<is_class<T>::value!=0>::type
fill_v(T &t, const V &v ) { for(auto &e:t) fill_v(e,v); }
 
template<class S,class T>
ostream &operator <<(ostream& out,const pair<S,T>& a){
  out<<'('<<a.first<<','<<a.second<<')'; return out;
}
template<typename T>
istream& operator >> (istream& is, vector<T>& vec){
  for(T& x: vec) {is >> x;} return is;
}
template<class T>
ostream &operator <<(ostream& out,const vector<T>& a){
  out<<'['; for(T i: a) {out<<i<<',';} out<<']'; return out;
}
 
int dx[] = {0, 1, 0, -1}, dy[] = {1, 0, -1, 0}; // DRUL
const int INF = 1<<30;
const ll LLINF = 1LL<<60;
const ll MOD = 1000000007;

signed main(void)
{
  cin.tie(0);
  ios::sync_with_stdio(false);

  ll n;
  cin >> n;
  ++n;
  vector<vector<ll>> g(n);
  REP(i, n-1) {
    ll p;
    cin >> p;
    g[p].push_back(i+1);
  }

  vector<ll> depth(n);
  function<void(ll,ll)> dep = [&](ll v, ll d) {
    depth[d]++;
    for(auto to: g[v]) {
      dep(to, d+1);
    }
  };

  struct node {
    // 頂点に0,1,2個ビー玉が存在するときについて
    ll x, y, z;
    node() {}
    node(ll x, ll y, ll z) : x(x), y(y), z(z) {}
    // dpの情報のマージ
    node plus(node n) {
      ll nx = x*n.x % MOD;
      ll ny = (y*n.x % MOD + x*n.y % MOD) % MOD;
      ll nz = (z*(n.x+n.y+n.z) % MOD + y*(n.y+n.z) % MOD + x*n.z % MOD) % MOD;
      return node(nx, ny, nz);
    }
  };

  function<deque<node>(ll)> dfs = [&](ll v) {
    // ret[i] = (頂点vでi回移動したときの情報)
    deque<node> ret;
    ll sz = 0;
    for(auto to: g[v]) {
      auto dq = dfs(to);
      if(ret.size() < dq.size()) swap(ret, dq);
      // dq -> retへマージ
      REP(i, dq.size()) ret[i] = ret[i].plus(dq[i]);
      chmax(sz, (ll)dq.size());
    }
    // ret.size()まで回すのではなくmax(dq.size())までに抑える
    REP(i, sz) {
      (ret[i].x += ret[i].z) %= MOD;
      ret[i].z = 0;
    }
    ret.push_front(node(1, 1, 0));
    return ret;
  };

  dep(0, 0);
  auto ret = dfs(0);

  vector<ll> pow2(n+1);
  pow2[0] = 1;
  FOR(i, 1, n+1) pow2[i] = pow2[i-1] * 2 % MOD;

  ll d = 0, ans = 0;
  for(auto i: ret) (ans += i.y * pow2[n - depth[d++]] % MOD) %= MOD;
  cout << ans << endl;

  return 0;
}