ferinの競プロ帳

競プロについてのメモ

エイシング プログラミング コンテスト 2019 E - Attack to a Tree

問題ページ

解法

木DPをする。dp[v][i][j] = (頂点vを根とする部分木でi回辺を切っていて頂点vの連結成分以外は問題文の条件を満たし、頂点vの連結成分が全てバッテリー(j=0)orそれ以外(j=1)のときの連結成分の頂点の重みの和の最小) とする。不可能な場合はinfをdpに持たせる。頂点vを根とする木とvの子cを根とする木をマージしていくことでdpの遷移を行う。頂点vとcの間の辺を切らない場合と切る場合に分けて考える。

切らない場合は各部分木でi回切断、j回切断している状態をマージするのでi+j回切断したときの情報が得られる。マージするどちらかのdpの値がinfであればそれは不可能な状態なので無視する。
new_dp[v][i+j][0] = dp[v][i][0] + dp[c][i][0]
new_dp[v][i+j][1] = dp[v][i][0] + dp[c][i][1]
new_dp[v][i+j][1] = dp[v][i][1] + dp[c][i][0]
new_dp[v][i+j][1] = dp[v][i][1] + dp[c][i][1]

切る場合はi+j+1回切断したときの情報が得られる。切っているので子cの部分木の頂点の重みは足す必要がない。辺を切ってvと連結でなくなる頂点の連結成分が問題文の条件を満たすときのみ辺を切断する遷移を行えることに注意。
new_dp[v][i+j+1][0] = dp[v][i][0]
new_dp[v][i+j+1][0] = dp[v][i][0]
new_dp[v][i+j+1][1] = dp[v][i][1]
new_dp[v][i+j+1][1] = dp[v][i][1]

以上のマージはO(N^2)でこれを各頂点について行うとO(N^3)に一見思えるが部分木のサイズまでしかループを回さないようにすることでO(N^2)となっている。二乗の木 DP - (iwi) { 反省します - TopCoder部
dp[root][i][0]=infかdp[root][i][1]<0であればi回切断が可能と判定でき最終的な答えが求められる。

#include <bits/stdc++.h>

using namespace std;
using ll = long long;
// #define int ll
using PII = pair<ll, ll>;

#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()

template<typename T> T &chmin(T &a, const T &b) { return a = min(a, b); }
template<typename T> T &chmax(T &a, const T &b) { return a = max(a, b); }
template<typename T> bool IN(T a, T b, T x) { return a<=x&&x<b; }
template<typename T> T ceil(T a, T b) { return a/b + !!(a%b); }

template<typename T> vector<T> make_v(size_t a) { return vector<T>(a); }
template<typename T,typename... Ts>
auto make_v(size_t a,Ts... ts) {
  return vector<decltype(make_v<T>(ts...))>(a,make_v<T>(ts...));
}
template<typename T,typename V> typename enable_if<is_class<T>::value==0>::type
fill_v(T &t, const V &v) { t=v; }
template<typename T,typename V> typename enable_if<is_class<T>::value!=0>::type
fill_v(T &t, const V &v ) { for(auto &e:t) fill_v(e,v); }

template<class S,class T>
ostream &operator <<(ostream& out,const pair<S,T>& a){
  out<<'('<<a.first<<','<<a.second<<')'; return out;
}
template<typename T>
istream& operator >> (istream& is, vector<T>& vec){
  for(T& x: vec) {is >> x;} return is;
}
template<class T>
ostream &operator <<(ostream& out,const vector<T>& a){
  out<<'['; for(T i: a) {out<<i<<',';} out<<']'; return out;
}

int dx[] = {0, 1, 0, -1}, dy[] = {1, 0, -1, 0}; // DRUL
const int INF = 1<<30;
const ll LLINF = 1LL<<60;
const ll MOD = 1000000007;

using vll = vector<vector<ll>>;
ll a[5010];
vector<ll> g[5010];
vll dfs(ll v, ll p) {
  // dp[v]
  vll ret(1, vector<ll>(2, LLINF));
  if(a[v]>0) ret[0][0] = a[v];
  ret[0][1] = a[v];
  for(auto to: g[v]) if(to != p) {
    // dp[c]
    auto vec = dfs(to, v);

    // new_dp[v]
    vll nret(ret.size() + vec.size(), vector<ll>(2, LLINF));
    REP(i, ret.size()) REP(j, vec.size()) {
      // vからiまでの辺を切断しない
      if(ret[i][0] != LLINF && vec[j][0] != LLINF) {
        chmin(nret[i+j][0], ret[i][0]+vec[j][0]);
      }
      if(ret[i][1] != LLINF && vec[j][1] != LLINF) {
        chmin(nret[i+j][1], ret[i][1]+vec[j][1]);
      }
      if(ret[i][0] != LLINF && vec[j][1] != LLINF) {
        chmin(nret[i+j][1], ret[i][0]+vec[j][1]);
      }
      if(ret[i][1] != LLINF && vec[j][0] != LLINF) {
        chmin(nret[i+j][1], ret[i][1]+vec[j][0]);
      }
      // vからiまでの辺を切断する
      if(ret[i][0] != LLINF && vec[j][0] != LLINF && i+j+1<nret.size()) {
        chmin(nret[i+j+1][0], ret[i][0]);
      }
      if(ret[i][0] != LLINF && vec[j][1] != LLINF && vec[j][1]<0 && i+j+1<nret.size()) {
        chmin(nret[i+j+1][0], ret[i][0]);
      }
      if(ret[i][1] != LLINF && vec[j][0] != LLINF && i+j+1<nret.size()) {
        chmin(nret[i+j+1][1], ret[i][1]);
      }
      if(ret[i][1] != LLINF && vec[j][1] != LLINF && vec[j][1]<0 && i+j+1<nret.size()) {
        chmin(nret[i+j+1][1], ret[i][1]);
      }
    }
    ret = nret;
  }
  return ret;
}

signed main(void) 
{
  cin.tie(0);
  ios::sync_with_stdio(false);

  ll n;
  cin >> n;
  REP(i, n) cin >> a[i];
  REP(i, n-1) {
    ll u, v;
    cin >> u >> v;
    u--, v--;
    g[u].push_back(v);
    g[v].push_back(u);
  }

  ll ret = INF;
  auto ans = dfs(0, -1);
  REP(i, ans.size()) REP(j, 2) {
    if(j==0 && ans[i][j] != LLINF) chmin(ret, i);
    if(j==1 && ans[i][j] < 0) chmin(ret, i);
  }
  cout << ret << endl;

  return 0;
}