ferinの競プロ帳

競プロについてのメモ

Codeforces Round #539 (Div. 2) F. Sasha and Interesting Fact from Graph Theory

問題ページ

解法

頂点a,bを結ぶパスの辺をe本と固定したときの組み合わせ数を求める。まずn-2個の頂点からパス上のe-1個の頂点を選ぶ方法がP(n-2, e-1)通りである。次にパス上の辺の重みを決定する方法について考える。mをe分割する方法はm個のボールの間(m-1個)にe-1個の仕切りを入れると考えることができC(m-1, e-1)通りである。
パス以外の辺の重みは自由に決められるのでm^(n-1-e)通りである。ラベルつきの森が何通りあるか求めるにはCayley's_formulaを用いればよい。頂点数n、連結成分数kの森はk*n^(n-1-k)通りである。今回の条件設定に当てはめると(e+1)*n^(n-2-e)通りである。e=n-1のときに注意。
まとめると辺数がeのときP(n-2,e-1)*C(m-1,e-1)*m^(n-1-e)*(e+1)*n^(n-2-e)通りとなる。これを1<=e<=n-1について足せばよい。

#include <bits/stdc++.h>

using namespace std;
using ll = long long;
// #define int ll
using PII = pair<ll, ll>;

#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()

template<typename T> T &chmin(T &a, const T &b) { return a = min(a, b); }
template<typename T> T &chmax(T &a, const T &b) { return a = max(a, b); }
template<typename T> bool IN(T a, T b, T x) { return a<=x&&x<b; }
template<typename T> T ceil(T a, T b) { return a/b + !!(a%b); }

template<typename T> vector<T> make_v(size_t a) { return vector<T>(a); }
template<typename T,typename... Ts>
auto make_v(size_t a,Ts... ts) { 
    return vector<decltype(make_v<T>(ts...))>(a,make_v<T>(ts...));
}
template<typename T,typename V> typename enable_if<is_class<T>::value==0>::type
fill_v(T &t, const V &v) { t=v; }
template<typename T,typename V> typename enable_if<is_class<T>::value!=0>::type
fill_v(T &t, const V &v ) { for(auto &e:t) fill_v(e,v); }

template<class S,class T>
ostream &operator <<(ostream& out,const pair<S,T>& a){
    out<<'('<<a.first<<','<<a.second<<')'; return out;
}
template<class T>
ostream &operator <<(ostream& out,const vector<T>& a){
    out<<'['; for(T i: a) {out<<i<<',';} out<<']'; return out;
}

int dx[] = {0, 1, 0, -1}, dy[] = {1, 0, -1, 0}; // DRUL
const int INF = 1<<30;
const ll LLINF = 1LL<<60;
const ll MOD = 1000000007;

template<ll MOD>
struct modint {
    ll x;
    modint(): x(0) {}
    modint(ll y) : x(y>=0 ? y%MOD : y%MOD+MOD) {}
    ll get() const { return x; }
    // e乗
    modint pow(ll e) {
        ll a = 1, p = x;
        while(e > 0) {
            if(e%2 == 0) {p = (p*p) % MOD; e /= 2;}
            else {a = (a*p) % MOD; e--;}
        }
        return modint(a);
    }
    ll extgcd(ll a, ll b, ll &x, ll &y) {
        ll g = a; x = 1, y = 0;
        if(b != 0) g = extgcd(b, a%b, y, x), y -= (a/b) * x;
        return g;
    }
    modint inv() {
        ll s, t; extgcd(x, MOD, s, t);
        return modint(s);
    }
    // Comparators
    bool operator <(modint b) { return x < b.x; }
    bool operator >(modint b) { return x > b.x; }
    bool operator<=(modint b) { return x <= b.x; }
    bool operator>=(modint b) { return x >= b.x; }
    bool operator!=(modint b) { return x != b.x; }
    bool operator==(modint b) { return x == b.x; }
    // increment, decrement
    modint operator++() { x++; return *this; }
    modint operator++(signed) { modint t = *this; x++; return t; }
    modint operator--() { x--; return *this; }
    modint operator--(signed) { modint t = *this; x--; return t; }
    // Basic Operations
    modint &operator+=(modint that) {
        x += that.x;
        if(x >= MOD) x -= MOD;
        return *this;
    }
    modint &operator-=(modint that) {
        x -= that.x;
        if(x < 0) x += MOD;
        return *this;
    }
    modint &operator*=(modint that) {
        x = (ll)x * that.x % MOD;
        return *this;
    }
    modint &operator/=(modint that) {
        x = (ll)x * that.inv().x % MOD;
        return *this;
    }
    modint &operator%=(modint that) {
        x = (ll)x % that.x;
        return *this;
    }
    modint operator+(modint that) const { return *this += that; }
    modint operator-(modint that) const { return *this -= that; }
    modint operator*(modint that) const { return *this *= that; }
    modint operator/(modint that) const { return *this /= that; }
    modint operator%(modint that) const { return *this %= that; }
};
using mint = modint<1000000007>;
// Input/Output
ostream &operator<<(ostream& os, mint a) { return os << a.x; }
istream &operator>>(istream& is, mint &a) { return is >> a.x; }

ll binpow(ll x, ll e) {
  ll ret = 1, p = x;
  while(e > 0) {
    if(e&1) {(ret *= p) %= MOD; e--;}
    else {(p *= p) %= MOD; e /= 2;} 
  }
  return ret;
}

template<bool bigN=false>
ll combi(ll N_, ll K_, ll mo=MOD) {
    const int NUM_=1e6+10;
    static ll fact[NUM_+1]={},factr[NUM_+1]={},inv[NUM_+1]={};
    auto binpow = [&](ll x, ll e) -> ll {
        ll a = 1, p = x;
        while(e > 0) {
        if(e%2 == 0) {p = (p*p) % mo; e /= 2;}
        else {a = (a*p) % mo; e--;}
        }
        return a;
    };
    if (fact[0]==0) {
        fact[0] = factr[0] = inv[0] = 1;
        FOR(i, 1, NUM_+1) fact[i] = fact[i-1] * i % MOD;
        factr[NUM_] = binpow(fact[NUM_], mo-2);
        for(int i=NUM_-1; i>=0; --i) factr[i] = factr[i+1] * (i+1) % MOD;
        // bigNがないならいらない
        // REP(i, NUM_+1) inv[i] = binpow(i, MOD-2);
    }
    if(K_<0 || K_>N_) return 0;
    // 前計算 O(max(N,K)) クエリ O(1)
    if(!bigN) return factr[K_]*fact[N_]%MOD*factr[N_-K_]%MOD;
    // Nが大きいけどKが小さい場合に使う 前計算 O(Klog(mod)) クエリ O(K)
    ll ret = 1;
    for(;K_>0;N_--,K_--) (ret *= N_%MOD) %= MOD, (ret *= inv[K_]) %= MOD;
    return ret;
}

signed main(void)
{
    cin.tie(0);
    ios::sync_with_stdio(false);

    ll n, m, a, b;
    cin >> n >> m >> a >> b;

    mint ret = 0, fact = 1;
    FOR(e, 1, n) {
        mint tmp = fact;
        tmp *= combi(m-1, e-1);
        tmp *= binpow(m, n-1-e);
        if(e != n-1) {
            tmp *= e+1;
            tmp *= binpow(n, n-2-e);
        }
        fact *= n-e-1;
        ret += tmp;
    }
    cout << ret << endl;

    return 0;
}