ferinの競プロ帳

競プロについてのメモ

SoundHound Programming Contest 2018 Masters Tournament 本戦 C - Not Too Close

問題ページ

考えたこと

長さ D のパスグラフを作ったあと、D 未満の経路ができないように辺をつなぐみたいな感じで考えてた
ラベルを無視して数え上げたあと頂点にラベルをつければいいかと思ったら、ラベルをつけるパートが大変でダブらずに数える方法が何もわからなくて終了
ラベルを無視していいなら包除っぽく最短距離が X 以下になるやつを数えていけばできると思ったんだけど、合ってるかはわからん

解法

頂点1からの距離でグラフを層に分解してDPをします。\text{dp} \lbrack i \rbrack  \lbrack j \rbrack  \lbrack k \rbrack  = ( 距離 i までで残り頂点が j で距離 i の頂点が k) とします。DPの遷移は距離 i+1 で頂点を l 個使用するとすると、\text{dp} \lbrack i+1 \rbrack  \lbrack j-l \rbrack  \lbrack l \rbrack  = \text{dp} \lbrack i \rbrack  \lbrack j \rbrack  \lbrack k \rbrack  \times \binom{j-1}{l} \times 2^{l(l-1)/2} \times (2^k-1)^l となります。

  • \binom{j-1}{l}
    残りの j-1 個から使用する l 個を取り出す方法
    頂点2は最後まで取っておかないといけないので j-1
  • 2^{l(l-1)/2}
    新たに使用する l 個の間の辺を張る方法
  • (2^k-1)^l
    l 個の頂点が距離 i のいずれかの頂点と結ばれる方法
    距離 i の全ての頂点と結ばれない方法はだめなので-1

このDPで距離が D-1 のところまで計算を行い、残りの頂点(2を含む)をどのようにつなげるか計算します。これは \sum \text{dp} \lbrack d-1 \rbrack  \lbrack i \rbrack  \lbrack j \rbrack  \times (2^j-1) \times (2^{j})^{i-1} \times 2^{i(i-1)/2} で計算できます。

  • 2^j-1
    頂点2を距離 D-1 の頂点とつなげる方法
  • (2^{j})^{i-1}
    2以外の残っている頂点を距離 D-1 の頂点とつなげる方法
    必ずしも繋げる必要はないので-1はいらない
  • 2^{i(i-1)/2}
    残っている頂点同士を結ぶ方法

感想

解説見て遷移考えてDP書いたら i=D-1 でだいぶハマった

ToDo: グラフが与えられたときに、ラベル付きグラフとして同型なものは重複してカウントしない条件で、ラベルの付け方が何通りあるか求めるのって可能なのか…?
ラベルをつけるのが難しい(のか?)と知っていればDPなり全探索系をするしかないと思えそう

層ごとに分割することで一つ前の層だけを見ればよくなって、DPできるようになるパターンは覚えておくべきっぽい

#include <bits/stdc++.h>    
using namespace std;    
using ll = long long;    
using PII = pair<ll, ll>;    
#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)    
#define REP(i, n) FOR(i, 0, n)    
#define ALL(x) x.begin(), x.end()    
template<typename T> void chmin(T &a, const T &b) { a = min(a, b); }    
template<typename T> void chmax(T &a, const T &b) { a = max(a, b); }    
struct FastIO {FastIO() { cin.tie(0); ios::sync_with_stdio(0); }}fastiofastio;    
#ifdef DEBUG_     
#include "../program_contest_library/memo/dump.hpp"    
#else    
#define dump(...)    
#endif    
const ll INF = 1LL<<60;  
  
template<ll MOD>  
struct modint {  
    ll x;  
    modint(): x(0) {}  
    modint(ll y) : x(y>=0 ? y%MOD : y%MOD+MOD) {}  
    static constexpr ll mod() { return MOD; }  
    // e乗  
    modint pow(ll e) {  
        ll a = 1, p = x;  
        while(e > 0) {  
            if(e%2 == 0) {p = (p*p) % MOD; e /= 2;}  
            else {a = (a*p) % MOD; e--;}  
        }  
        return modint(a);  
    }  
    modint inv() const {  
        ll a=x, b=MOD, u=1, y=1, v=0, z=0;  
        while(a) {  
            ll q = b/a;  
            swap(z -= q*u, u);  
            swap(y -= q*v, v);  
            swap(b -= q*a, a);  
        }  
        return z;  
    }  
    // Comparators  
    bool operator <(modint b) { return x < b.x; }  
    bool operator >(modint b) { return x > b.x; }  
    bool operator<=(modint b) { return x <= b.x; }  
    bool operator>=(modint b) { return x >= b.x; }  
    bool operator!=(modint b) { return x != b.x; }  
    bool operator==(modint b) { return x == b.x; }  
    // Basic Operations  
    modint operator+(modint r) const { return modint(*this) += r; }  
    modint operator-(modint r) const { return modint(*this) -= r; }  
    modint operator*(modint r) const { return modint(*this) *= r; }  
    modint operator/(modint r) const { return modint(*this) /= r; }  
    modint &operator+=(modint r) {  
        if((x += r.x) >= MOD) x -= MOD;  
        return *this;  
    }  
    modint &operator-=(modint r) {  
        if((x -= r.x) < 0) x += MOD;  
        return *this;  
    }  
    modint &operator*=(modint r) {  
    #if !defined(_WIN32) || defined(_WIN64)  
        x = x * r.x % MOD; return *this;  
    #endif  
        unsigned long long y = x * r.x;  
        unsigned xh = (unsigned) (y >> 32), xl = (unsigned) y, d, m;  
        asm(  
            "divl %4; \n\t"  
            : "=a" (d), "=d" (m)  
            : "d" (xh), "a" (xl), "r" (MOD)  
        );  
        x = m;  
        return *this;  
    }  
    modint &operator/=(modint r) { return *this *= r.inv(); }  
    // increment, decrement  
    modint operator++() { x++; return *this; }  
    modint operator++(signed) { modint t = *this; x++; return t; }  
    modint operator--() { x--; return *this; }  
    modint operator--(signed) { modint t = *this; x--; return t; }  
    // 平方剰余のうち一つを返す なければ-1  
    friend modint sqrt(modint a) {  
        if(a == 0) return 0;  
        ll q = MOD-1, s = 0;  
        while((q&1)==0) q>>=1, s++;  
        modint z=2;  
        while(1) {  
            if(z.pow((MOD-1)/2) == MOD-1) break;  
            z++;  
        }  
        modint c = z.pow(q), r = a.pow((q+1)/2), t = a.pow(q);  
        ll m = s;  
        while(t.x>1) {  
            modint tp=t;  
            ll k=-1;  
            FOR(i, 1, m) {  
                tp *= tp;  
                if(tp == 1) { k=i; break; }  
            }  
            if(k==-1) return -1;  
            modint cp=c;  
            REP(i, m-k-1) cp *= cp;  
            c = cp*cp, t = c*t, r = cp*r, m = k;  
        }  
        return r.x;  
    }  
  
    template<class T>  
    friend modint operator*(T l, modint r) { return modint(l) *= r; }  
    template<class T>  
    friend modint operator+(T l, modint r) { return modint(l) += r; }  
    template<class T>  
    friend modint operator-(T l, modint r) { return modint(l) -= r; }  
    template<class T>  
    friend modint operator/(T l, modint r) { return modint(l) /= r; }  
    template<class T>  
    friend bool operator==(T l, modint r) { return modint(l) == r; }  
    template<class T>  
    friend bool operator!=(T l, modint r) { return modint(l) != r; }  
    // Input/Output  
    friend ostream &operator<<(ostream& os, modint a) { return os << a.x; }  
    friend istream &operator>>(istream& is, modint &a) {   
        is >> a.x;  
        a.x = ((a.x%MOD)+MOD)%MOD;  
        return is;  
    }  
    friend string to_frac(modint v) {  
        static map<ll, PII> mp;  
        if(mp.empty()) {  
            mp[0] = mp[MOD] = {0, 1};  
            FOR(i, 2, 1001) FOR(j, 1, i) if(__gcd(i, j) == 1) {  
                mp[(modint(i) / j).x] = {i, j};  
            }  
        }  
        auto itr = mp.lower_bound(v.x);  
        if(itr != mp.begin() && v.x - prev(itr)->first < itr->first - v.x) --itr;  
        string ret = to_string(itr->second.first + itr->second.second * ((int)v.x - itr->first));  
        if(itr->second.second > 1) {  
            ret += '/';  
            ret += to_string(itr->second.second);  
        }  
        return ret;  
    }  
};  
using mint = modint<1000000007>;  
  
// 前計算O(N) クエリO(1)  
mint combi(ll N, ll K) {  
    const int maxN=5e5; // !!!  
    static mint fact[maxN+1]={},factr[maxN+1]={};  
    if (fact[0]==0) {  
        fact[0] = factr[0] = 1;  
        FOR(i, 1, maxN+1) fact[i] = fact[i-1] * i;  
        factr[maxN] = fact[maxN].inv();  
        for(ll i=maxN-1; i>=0; --i) factr[i] = factr[i+1] * (i+1);  
    }  
    if(K<0 || K>N) return 0; // !!!  
    return factr[K]*fact[N]*factr[N-K];  
}  
  
mint dp[50][50][50];  
signed main() {  
    ll n, d;  
    cin >> n >> d;  
  
    vector<mint> pw2(1000);  
    pw2[0] = 1;  
    FOR(i, 1, 1000) pw2[i] = pw2[i-1] * 2;  
  
    dp[0][n-1][1] = 1;  
    REP(i, d-1) FOR(j, 1, n+1) REP(k, n+1) {  
        if(dp[i][j][k] == 0) continue;  
        mint p = pw2[k]-1;  
        FOR(l, 1, j+1) {  
            dp[i+1][j-l][l] += dp[i][j][k] * combi(j-1, l) * pw2[l*(l-1)/2] * p;  
            p *= pw2[k]-1;  
        }  
    }  
  
    mint ret = 0;  
    // 距離d-1の頂点がj個 残りがi個  
    // 2とつなげる方法が 2^j-1  
    // 2以外とつなげる方法が 2^j^i  
    // 残り分をつなげる方法が 2^(i*(i-1)/2)  
    FOR(i, 1, n+1) FOR(j, 1, n+1) ret += dp[d-1][i][j] * (pw2[j]-1) * pw2[j].pow(i-1) * pw2[i*(i-1)/2];  
    cout << ret << endl;  
  
    return 0;  
}