ferinの競プロ帳

競プロについてのメモ

EDPC J - Sushi

問題ページ

解法

寿司がa[i]個乗っている皿がどこにあったとしても選ばれる確率に影響はない。したがってN要素の数列a[i]ではなくcnt[i]=(i個乗っている皿の数)と情報を持つことができる。
dfs(x, y, z) = (1個乗っている皿がx枚、2個乗っている皿がy枚、3個乗っている皿がz枚で全ての寿司が無くなる回数の期待値) としてメモ化再帰をする。1個乗っている皿が選ばれる確率がx/n、2個乗っている皿が選ばれる確率がy/n、3個乗っている皿が選ばれる確率がz/n、寿司が乗っていない皿が選ばれる確率が(n-x-y-z)/nである。寿司が乗っていない皿が選ばれた場合にdfs(x,y,z)と素直にdfs関数を呼び出すと無限ループになってしまう。これを回避するため遷移の式以下のように変形する。
dfs(x, y, z) = dfs(x-1, y, z) * x/n + dfs(x+1, y-1, z) * y/n + dfs(x, y+1, z-1) * z/n + dfs(x, y, z) * (n-x-y-z)/n + 1
(1-(n-x-y-z)/n) * dfs(x, y, z) = dfs(x-1, y, z) * x/n + dfs(x+1, y-1, z) * y/n + dfs(x, y+1, z-1) * z/n + 1
dfs(x, y, z) = dfs(x-1, y, z) * x/(x+y+z) + dfs(x+1, y-1, z) * y/(x+y+z) + dfs(x, y+1, z-1) * z/(x+y+z) + n/(x+y+z)
状態数がO(N^3)、遷移がO(1)なのでこの問題は解けた。

#include <bits/stdc++.h>
 
using namespace std;
using ll = long long;
// #define int ll
using PII = pair<ll, ll>;
 
#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()
 
template<typename T> T &chmin(T &a, const T &b) { return a = min(a, b); }
template<typename T> T &chmax(T &a, const T &b) { return a = max(a, b); }
template<typename T> bool IN(T a, T b, T x) { return a<=x&&x<b; }
template<typename T> T ceil(T a, T b) { return a/b + !!(a%b); }
 
template<typename T> vector<T> make_v(size_t a) { return vector<T>(a); }
template<typename T,typename... Ts>
auto make_v(size_t a,Ts... ts) { 
  return vector<decltype(make_v<T>(ts...))>(a,make_v<T>(ts...));
}
template<typename T,typename V> typename enable_if<is_class<T>::value==0>::type
fill_v(T &t, const V &v) { t=v; }
template<typename T,typename V> typename enable_if<is_class<T>::value!=0>::type
fill_v(T &t, const V &v ) { for(auto &e:t) fill_v(e,v); }
 
template<class S,class T>
ostream &operator <<(ostream& out,const pair<S,T>& a){
  out<<'('<<a.first<<','<<a.second<<')'; return out;
}
template<typename T>
istream& operator >> (istream& is, vector<T>& vec){
  for(T& x: vec) {is >> x;} return is;
}
template<class T>
ostream &operator <<(ostream& out,const vector<T>& a){
  out<<'['; for(T i: a) {out<<i<<',';} out<<']'; return out;
}
 
int dx[] = {0, 1, 0, -1}, dy[] = {1, 0, -1, 0}; // DRUL
const int INF = 1<<30;
const ll LLINF = 1LL<<60;
const int MOD = 1000000007;

const double EPS = 1e-8;
bool used[305][305][305];
double dp[305][305][305];
ll n, cnt[4];

double dfs(ll x, ll y, ll z) {
  if(!x && !y && !z) return 0;
  if(used[x][y][z]) return dp[x][y][z];
  double ret = (double)n/(x+y+z);
  if(x) ret += dfs(x-1, y, z) * (double)x / (x+y+z);
  if(y) ret += dfs(x+1, y-1, z) * (double)y / (x+y+z);
  if(z) ret += dfs(x, y+1, z-1) * (double)z / (x+y+z);
  used[x][y][z] = true;
  return dp[x][y][z] = ret;
}

signed main(void)
{
  cin.tie(0);
  ios::sync_with_stdio(false);

  cin >> n;
  REP(i, n) {
    ll a;
    cin >> a;
    cnt[a]++;
  }

  cout << fixed << setprecision(15) << dfs(cnt[1], cnt[2], cnt[3]) << endl;

  return 0;
}