Codeforces Round #539 (Div. 2) F. Sasha and Interesting Fact from Graph Theory
解法
頂点a,bを結ぶパスの辺数をeに固定したときの組み合わせ数を求める。まず、a,b以外のn-2個の頂点からパスの内部頂点となるe-1個を選んで並べる方法がP(n-2, e-1)通りである（並び順によってパスが変わるので順列になる）。次にパス上の辺の重みを決める方法を考える。e本の辺に正の整数の重みを割り当てて合計をmにする方法（mをe個の正整数に分ける方法）は、m個のボールの間のm-1箇所からe-1箇所を選んで仕切りを入れると考えられるのでC(m-1, e-1)通りである。
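パス以外のn-1-e本の辺の重みは1..mから自由に決められるのでm^(n-1-e)通りである。残りの頂点をパスにつなぐ方法の数は、パス上のe+1個の頂点がそれぞれ異なる木に属するようなラベル付きの森の数に等しく、一般化されたCayleyの公式（Cayley's formula）で求められる。すなわち、頂点数nのラベル付きの森のうち、指定したk個の頂点がそれぞれ別の木に属するもの（連結成分数k）はk*n^(n-1-k)通りである。k=e+1を代入すると(e+1)*n^(n-2-e)通りとなる。e=n-1のときは指数が-1になるので注意が必要である。このときはすべての頂点がパス上にあるため1通りとする（k*n^(n-1-k)=n*n^(-1)=1とも一致する）。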
まとめると、辺数がeのときP(n-2,e-1)*C(m-1,e-1)*m^(n-1-e)*(e+1)*n^(n-2-e)通りとなる（ただしe=n-1のときは(e+1)*n^(n-2-e)の部分を1とする）。これを1<=e<=n-1について足し合わせ、10^9+7で割った余りを出力すればよい。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// #define int ll
using PII = pair<ll, ll>;
#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()

template<typename T> T &chmin(T &a, const T &b) { return a = min(a, b); }
template<typename T> T &chmax(T &a, const T &b) { return a = max(a, b); }
// True iff x lies in the half-open interval [a, b).
template<typename T> bool IN(T a, T b, T x) { return a <= x && x < b; }
// Integer ceiling division for non-negative operands.
template<typename T> T ceil(T a, T b) { return a / b + !!(a % b); }

// make_v<T>(d0, d1, ...) builds a nested vector with the given dimensions.
template<typename T> vector<T> make_v(size_t a) { return vector<T>(a); }
template<typename T, typename... Ts> auto make_v(size_t a, Ts... ts) {
    return vector<decltype(make_v<T>(ts...))>(a, make_v<T>(ts...));
}
// fill_v(t, v) assigns v to every scalar element of a (possibly nested) container.
template<typename T, typename V>
typename enable_if<is_class<T>::value == 0>::type fill_v(T &t, const V &v) { t = v; }
template<typename T, typename V>
typename enable_if<is_class<T>::value != 0>::type fill_v(T &t, const V &v) {
    for (auto &e : t) fill_v(e, v);
}

template<class S, class T> ostream &operator<<(ostream &out, const pair<S, T> &a) {
    out << '(' << a.first << ',' << a.second << ')';
    return out;
}
template<class T> ostream &operator<<(ostream &out, const vector<T> &a) {
    out << '[';
    for (const auto &i : a) out << i << ',';
    out << ']';
    return out;
}

int dx[] = {0, 1, 0, -1}, dy[] = {1, 0, -1, 0};  // DRUL
const int INF = 1 << 30;
const ll LLINF = 1LL << 60;
const ll MOD = 1000000007;

// Integer modulo MOD. The invariant 0 <= x < MOD holds after every operation.
template<ll MOD> struct modint {
    ll x;
    modint() : x(0) {}
    // Reduces any y into [0, MOD). The double-mod also maps negative multiples of
    // MOD to 0 (the old `y%MOD+MOD` produced MOD for them).
    modint(ll y) : x((y % MOD + MOD) % MOD) {}
    ll get() const { return x; }

    // Returns x^e by binary exponentiation. Expects e >= 0; returns 1 for e <= 0.
    modint pow(ll e) const {
        ll a = 1, p = x;
        while (e > 0) {
            if (e % 2 == 0) { p = (p * p) % MOD; e /= 2; }
            else { a = (a * p) % MOD; e--; }
        }
        return modint(a);
    }
    // Extended Euclid: finds x, y with a*x + b*y == gcd(a, b), and returns the gcd.
    static ll extgcd(ll a, ll b, ll &x, ll &y) {
        ll g = a;
        x = 1, y = 0;
        if (b != 0) g = extgcd(b, a % b, y, x), y -= (a / b) * x;
        return g;
    }
    // Modular inverse via extgcd. Only meaningful when gcd(x, MOD) == 1.
    modint inv() const {
        ll s, t;
        extgcd(x, MOD, s, t);
        return modint(s);
    }

    // Comparators (on the canonical representative)
    bool operator<(modint b) const { return x < b.x; }
    bool operator>(modint b) const { return x > b.x; }
    bool operator<=(modint b) const { return x <= b.x; }
    bool operator>=(modint b) const { return x >= b.x; }
    bool operator!=(modint b) const { return x != b.x; }
    bool operator==(modint b) const { return x == b.x; }

    // increment, decrement (wrap around so x stays in [0, MOD))
    modint &operator++() { if (++x == MOD) x = 0; return *this; }
    modint operator++(signed) { modint t = *this; ++*this; return t; }
    modint &operator--() { if (--x < 0) x += MOD; return *this; }
    modint operator--(signed) { modint t = *this; --*this; return t; }

    // Basic Operations
    modint &operator+=(modint that) { x += that.x; if (x >= MOD) x -= MOD; return *this; }
    modint &operator-=(modint that) { x -= that.x; if (x < 0) x += MOD; return *this; }
    modint &operator*=(modint that) { x = x * that.x % MOD; return *this; }
    modint &operator/=(modint that) { x = x * that.inv().x % MOD; return *this; }
    modint &operator%=(modint that) { x = x % that.x; return *this; }
    // Binary operators work on a copy. The old versions called the non-const
    // compound operator on *this inside a const member, which does not compile
    // once instantiated.
    modint operator+(modint that) const { modint r = *this; r += that; return r; }
    modint operator-(modint that) const { modint r = *this; r -= that; return r; }
    modint operator*(modint that) const { modint r = *this; r *= that; return r; }
    modint operator/(modint that) const { modint r = *this; r /= that; return r; }
    modint operator%(modint that) const { modint r = *this; r %= that; return r; }
};
using mint = modint<1000000007>;

// Input/Output
ostream &operator<<(ostream &os, mint a) { return os << a.x; }
// Reads a raw integer and reduces it, so the modint invariant holds after input.
istream &operator>>(istream &is, mint &a) {
    ll v;
    if (is >> v) a = mint(v);
    return is;
}

// x^e mod MOD for e >= 0. The base is reduced first so p*p cannot overflow.
ll binpow(ll x, ll e) {
    ll ret = 1, p = x % MOD;
    if (p < 0) p += MOD;
    while (e > 0) {
        if (e & 1) { ret = ret * p % MOD; e--; }
        else { p = p * p % MOD; e /= 2; }
    }
    return ret;
}

// Binomial coefficient C(N_, K_) mod `mo`; returns 0 when K_ < 0 or K_ > N_.
// Requires `mo` to be prime and larger than NUM_, because factorials are inverted
// with Fermat's little theorem.
// - bigN == false: N_ must be <= NUM_. Query O(1).
// - bigN == true : N_ may be large, K_ must be <= NUM_. Query O(K_).
// Tables are built on the first call and rebuilt if a different `mo` is passed.
// The old version mixed MOD and `mo` while building them, and never filled inv[],
// so the bigN path always returned 0.
template<bool bigN = false> ll combi(ll N_, ll K_, ll mo = MOD) {
    const int NUM_ = 1e6 + 10;
    static ll fact[NUM_ + 1] = {}, factr[NUM_ + 1] = {}, inv[NUM_ + 1] = {};
    static ll built_for = 0;  // modulus the tables currently hold (0 = not built yet)
    auto pw = [mo](ll x, ll e) -> ll {
        ll a = 1, p = x % mo;
        while (e > 0) {
            if (e % 2 == 0) { p = (p * p) % mo; e /= 2; }
            else { a = (a * p) % mo; e--; }
        }
        return a;
    };
    if (built_for != mo) {
        fact[0] = factr[0] = inv[0] = 1;
        FOR(i, 1, NUM_ + 1) fact[i] = fact[i - 1] * i % mo;
        factr[NUM_] = pw(fact[NUM_], mo - 2);
        for (int i = NUM_ - 1; i >= 0; --i) factr[i] = factr[i + 1] * (i + 1) % mo;
        // 1/i = (i-1)! / i!  -- needed by the bigN path
        FOR(i, 1, NUM_ + 1) inv[i] = factr[i] * fact[i - 1] % mo;
        built_for = mo;
    }
    if (K_ < 0 || K_ > N_) return 0;
    if (!bigN) {
        assert(N_ <= NUM_);
        return factr[K_] * fact[N_] % mo * factr[N_ - K_] % mo;
    }
    // C(N, K) = N (N-1) ... (N-K+1) / K!
    assert(K_ <= NUM_);
    ll ret = 1;
    for (; K_ > 0; N_--, K_--) {
        ret = ret * (N_ % mo) % mo;
        ret = ret * inv[K_] % mo;
    }
    return ret;
}

// Counts labelled trees on n vertices with edge weights in [1, m] whose a-b path
// has total weight exactly m (mod 1e9+7). For a path of e edges the count is
//   P(n-2, e-1) * C(m-1, e-1) * m^(n-1-e) * (e+1) * n^(n-2-e),
// where the last two factors are taken as 1 when e == n-1. The result does not
// depend on the labels a != b.
mint solve(ll n, ll m) {
    mint ret = 0;
    mint perm = 1;  // P(n-2, e-1): ordered choice of the e-1 inner path vertices
    for (ll e = 1; e <= n - 1; ++e) {
        mint term = perm;
        term *= combi(m - 1, e - 1);   // path weights: compositions of m into e parts
        term *= binpow(m, n - 1 - e);  // weights of the n-1-e off-path edges
        if (e != n - 1) {
            // forests where the e+1 path vertices lie in distinct trees
            term *= e + 1;
            term *= binpow(n, n - 2 - e);
        }
        ret += term;
        perm *= n - e - 1;  // advance P(n-2, e-1) -> P(n-2, e)
    }
    return ret;
}

signed main(void) {
    cin.tie(0);
    ios::sync_with_stdio(false);
    ll n, m, a, b;
    cin >> n >> m >> a >> b;
    (void)a;  // the endpoints do not affect the count
    (void)b;
    cout << solve(n, m) << '\n';
    return 0;
}