Codeforces Round #439 (Div. 2) C. The Intriguing Obsession
問題ページ
Problem - C - Codeforces
概要
赤い島がa個、青い島がb個、紫の島がc個ある。これらの島に橋をかける。橋の長さを1としたとき、同じ色の島で最短距離が3未満になるような島のペアが存在しないようにしたい。このとき、橋をかける方法が何通りあるのかをmod 998244353で求めろ。
考えたこと
- 赤-青、青-紫、紫-赤の繋ぎ方が何通りなのかをそれぞれ考えてみる
- 島の数が1-1のときは2通り、1-2のときは3通り、2-2のときは7通りになっている
- これは組み合わせを頑張れば求められそうな気持ちになったので、求まったとして全体が何通りになるかを考える
- 最短距離が3以上の制約より赤-青の繋ぎ方が青-紫、紫-赤の繋ぎ方に影響を与えることはなさそう
- 独立に考えられそうなので(赤-青の通り数)×(青-紫の通り数)×(紫-赤の通り数)で全体が求まりそう
- それぞれの通り数を求める方法について考える
- dp[i][j] = (島の数がi, jの対になっているときの通り数)とする
- 青い島が2個、紫の島が3のときどうなってるか考えてみる
青1と紫1をつないだとき→dp[1][2]通り
青1と紫2をつないだとき→dp[1][2]通り
青1と紫3をつないだとき→dp[1][2]通り
青2と紫1をつないだとき→1通り(重複して数えないように青1と紫をつなぐ繋ぎ方は数えない)
青2と紫2をつないだとき→1通り
青2と紫3をつないだとき→1通り
何も繋がないとき→1通り
したがって(dp[1][2] + 1) * 3 + 1通りになる - これを一般化するとdp[i][j] = (dp[i-1][j-1] + dp[i-2][j-1] + … + dp[1][j-1] + 1) * j + 1通りになる
- 素直にDPするとO(max(a, b, c)^3)だが累積和の部分を別に計算しておくとO(max(a, b, c)^2)で間に合う
modint構造体を使うと便利だけど定数倍が重いのでちょっとこわい
公式解説の方法
島の数がA-Bで橋の数がkのとき、C(A,k) * C(B, k) * k! で繋ぎ方が何通りあるか求まる。
よって島の数がA-Bのときはsum( k=0 to min(A, B), C(A,k) * C(B, k) * k!) で求められる。
これはO(max(a, b, c)^2)で計算できる。
どう考えてもこっちのほうが頭がいい。
#include <bits/stdc++.h> using namespace std; typedef long long ll; #define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i) #define REP(i, n) FOR(i, 0, n) template<unsigned MOD> class ModInt { public: unsigned x; ModInt(): x(0) { } ModInt(signed y) : x(y >= 0 ? y % MOD : MOD - (-y) % MOD) {} unsigned get() const { return x; } // 逆数 ModInt inv() const { ll a = 1, p = x, e = MOD-2; while(e > 0) { if(e%2 == 0) {p = (p*p) % MOD; e /= 2;} else {a = (a*p) % MOD; e--;} } a %= MOD; return ModInt(a); } // e乗 ModInt pow(ll e) { ll a = 1, p = x; while(e > 0) { if(e%2 == 0) {p = (p*p) % MOD; e /= 2;} else {a = (a*p) % MOD; e--;} } a %= MOD; return ModInt(a); } // 2のx乗 ModInt pow2() { ll a = 1, p = 2, e = x; while(e > 0) { if(e%2 == 0) {p = (p*p) % MOD; e /= 2;} else {a = (a*p) % MOD; e--;} } a %= MOD; return ModInt(a); } // Comparators bool operator <(ModInt b) { return x < b.x; } bool operator >(ModInt b) { return x > b.x; } bool operator<=(ModInt b) { return x <= b.x; } bool operator>=(ModInt b) { return x >= b.x; } bool operator!=(ModInt b) { return x != b.x; } bool operator==(ModInt b) { return x == b.x; } // increment, decrement ModInt operator++() { x++; return *this; } ModInt operator++(int) {ModInt a = *this; x++; return a;} ModInt operator--() { x--; return *this; } ModInt operator--(int) {ModInt a = *this; x--; return a;} // Basic Operations ModInt &operator+=(ModInt that) { x = ((ll)x+that.x)%MOD; return *this; } ModInt &operator-=(ModInt that) { x = ((((ll)x-that.x)%MOD)+MOD)%MOD; return *this; } ModInt &operator*=(ModInt that) { x = (ll)x * that.x % MOD; return *this; } // O(log(mod))かかるので注意 ModInt &operator/=(ModInt that) { x = (ll)x * that.inv() % MOD; return *this; } ModInt &operator%=(ModInt that) { x = (ll)x % that.x; return *this; } ModInt operator+(ModInt that)const{return ModInt(*this) += that;} ModInt operator-(ModInt that)const{return ModInt(*this) -= that;} ModInt operator*(ModInt that)const{return ModInt(*this) *= that;} ModInt operator/(ModInt that)const{return ModInt(*this) /= that;} ModInt operator%(ModInt that)const{return ModInt(*this) %= that;} }; typedef ModInt<998244353> mint; // Input/Output ostream &operator<<(ostream& os, mint a) { return os << a.x; } istream &operator>>(istream& is, mint &a) { return is >> a.x; } mint dp[5010][5010], dp2[5010][5010]; signed main(void) { FOR(i, 1, 5001) { FOR(j, 1, 5001) { if(i == 1 || j == 1) dp[i][j] = max(i, j)+1; else dp[i][j] = (dp2[i-1][j-1]+1)*j + 1; dp2[i][j] = dp2[i-1][j] + dp[i][j]; } } int a, b, c; cin >> a >> b >> c; cout << dp[a][b] * dp[b][c] * dp[a][c] << endl; return 0; }