ferinの競プロ帳

競プロについてのメモ

Codeforces Round #439 (Div. 2) C. The Intriguing Obsession

問題ページ
Problem - C - Codeforces

概要

赤い島がa個、青い島がb個、紫の島がc個ある。これらの島に橋をかける。橋の長さを1としたとき、同じ色の島で最短距離が3未満になるような島のペアが存在しないようにしたい。このとき、橋をかける方法が何通りあるのかをmod 998244353で求めろ。

考えたこと

  1. 赤-青、青-紫、紫-赤の繋ぎ方が何通りなのかをそれぞれ考えてみる
  2. 島の数が1-1のときは2通り、1-2のときは3通り、2-2のときは7通りになっている
  3. これは組み合わせを頑張れば求められそうな気持ちになったので、求まったとして全体が何通りになるかを考える
  4. 最短距離が3以上の制約より赤-青の繋ぎ方が青-紫、紫-赤の繋ぎ方に影響を与えることはなさそう
  5. 独立に考えられそうなので(赤-青の通り数)×(青-紫の通り数)×(紫-赤の通り数)で全体が求まりそう
  6. それぞれの通り数を求める方法について考える
  7. dp[i][j] = (島の数がi, jの対になっているときの通り数)とする
  8. 青い島が2個、紫の島が3のときどうなってるか考えてみる
    f:id:ferin_tech:20171126182723p:plain
    青1と紫1をつないだとき→dp[1][2]通り
    青1と紫2をつないだとき→dp[1][2]通り
    青1と紫3をつないだとき→dp[1][2]通り
    青2と紫1をつないだとき→1通り(重複して数えないように青1と紫をつなぐ繋ぎ方は数えない)
    青2と紫2をつないだとき→1通り
    青2と紫3をつないだとき→1通り
    何も繋がないとき→1通り
    したがって(dp[1][2] + 1) * 3 + 1通りになる
  9. これを一般化するとdp[i][j] = (dp[i-1][j-1] + dp[i-2][j-1] + … + dp[1][j-1] + 1) * j + 1通りになる
  10. 素直にDPするとO(max(a, b, c)^3)だが累積和の部分を別に計算しておくとO(max(a, b, c)^2)で間に合う

modint構造体を使うと便利だけど定数倍が重いのでちょっとこわい

公式解説の方法

島の数がA-Bで橋の数がkのとき、C(A,k) * C(B, k) * k! で繋ぎ方が何通りあるか求まる。
よって島の数がA-Bのときはsum( k=0 to min(A, B), C(A,k) * C(B, k) * k!) で求められる。
これはO(max(a, b, c)^2)で計算できる。

どう考えてもこっちのほうが頭がいい。

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)

template<unsigned MOD>
class ModInt {
public:
  unsigned x;
  ModInt(): x(0) { }
  ModInt(signed y) : x(y >= 0 ? y % MOD : MOD - (-y) % MOD) {}
  unsigned get() const { return x; }

  // 逆数
  ModInt inv() const {
    ll a = 1, p = x, e = MOD-2;
    while(e > 0) {
      if(e%2 == 0) {p = (p*p) % MOD; e /= 2;}
      else {a = (a*p) % MOD; e--;}
    }
    a %= MOD;
    return ModInt(a);
  }
  // e乗
  ModInt pow(ll e) {
    ll a = 1, p = x;
    while(e > 0) {
      if(e%2 == 0) {p = (p*p) % MOD; e /= 2;}
      else {a = (a*p) % MOD; e--;}
    }
    a %= MOD;
    return ModInt(a);
  }
  // 2のx乗
  ModInt pow2() {
    ll a = 1, p = 2, e = x;
    while(e > 0) {
      if(e%2 == 0) {p = (p*p) % MOD; e /= 2;}
      else {a = (a*p) % MOD; e--;}
    }
    a %= MOD;
    return ModInt(a);
  }

  // Comparators
  bool operator <(ModInt b) { return x < b.x; }
  bool operator >(ModInt b) { return x > b.x; }
  bool operator<=(ModInt b) { return x <= b.x; }
  bool operator>=(ModInt b) { return x >= b.x; }
  bool operator!=(ModInt b) { return x != b.x; }
  bool operator==(ModInt b) { return x == b.x; }

  // increment, decrement
  ModInt operator++() { x++; return *this; }
  ModInt operator++(int) {ModInt a = *this; x++; return a;}
  ModInt operator--() { x--; return *this; }
  ModInt operator--(int) {ModInt a = *this; x--; return a;}

  // Basic Operations
  ModInt &operator+=(ModInt that) {
    x = ((ll)x+that.x)%MOD;
    return *this;
  }
  ModInt &operator-=(ModInt that) {
    x = ((((ll)x-that.x)%MOD)+MOD)%MOD;
    return *this;
  }
  ModInt &operator*=(ModInt that) {
    x = (ll)x * that.x % MOD;
    return *this;
  }
  // O(log(mod))かかるので注意
  ModInt &operator/=(ModInt that) {
    x = (ll)x * that.inv() % MOD;
    return *this;
  }
  ModInt &operator%=(ModInt that) {
    x = (ll)x % that.x;
    return *this;
  }
  ModInt operator+(ModInt that)const{return ModInt(*this) += that;}
  ModInt operator-(ModInt that)const{return ModInt(*this) -= that;}
  ModInt operator*(ModInt that)const{return ModInt(*this) *= that;}
  ModInt operator/(ModInt that)const{return ModInt(*this) /= that;}
  ModInt operator%(ModInt that)const{return ModInt(*this) %= that;}
};
typedef ModInt<998244353> mint;
// Input/Output
ostream &operator<<(ostream& os, mint a) { return os << a.x; }
istream &operator>>(istream& is, mint &a) { return is >> a.x; }

mint dp[5010][5010], dp2[5010][5010];
signed main(void)
{
  FOR(i, 1, 5001) {
    FOR(j, 1, 5001) {
      if(i == 1 || j == 1) dp[i][j] = max(i, j)+1;
      else dp[i][j] = (dp2[i-1][j-1]+1)*j + 1;
      dp2[i][j] = dp2[i-1][j] + dp[i][j];
    }
  }

  int a, b, c;
  cin >> a >> b >> c;
  cout << dp[a][b] * dp[b][c] * dp[a][c] << endl;

  return 0;
}