ferinの競プロ帳

競プロについてのメモ

COLOCON -Colopl programming contest 2018- Final D - Chaos of the Snuke World

問題ページ

順列を決めたときの転倒数の期待値を考えます.i \lt j について,X_{ij}j 番目のほうが小さいならば1,それ以外ならば0となる変数とします.このとき,転倒数の期待値は E \lbrack \sum_{i \lt j} X_{ij} \rbrack となります.期待値の線形性から \sum_{i \lt j} E \lbrack X_{ij} \rbrack と一致します.E \lbrack X_{ij} \rbrack A_i  \gt  A_j, A_i  \gt  B_j, B_i  \gt  A_j, B_i  \gt  B_j が1個成り立つごとに 1/4 増えます. 1/4 \times \sum_{i \lt j} (A_i  \gt  A_j) + (A_i  \gt  B_j) + (B_i  \gt  A_j) + (B_i  \gt  B_j) を最小化するような順列を求める問題になりました.

i,j 番目の要素について i の次に j と並べたときの転倒数の期待値が x/4 のとき,j の次に i と並べたときの転倒数の期待値は 1-x/4 となります.したがって,全ての2要素間について転倒数の期待値が 0,1/4,1/2 となるような並べ方ができれば,最適な並べ方です.

任意のペアについて A_i  \lt  B_i としても一般性は失いません.ペアを辞書順でソートすると,i \lt j について A_i  \gt  A_j, A_i  \gt  B_j となることはないため,全ての2要素間について転倒数の期待値は 1/2 以下となります.したがって,この並べ方について転倒数の期待値を求められればよいです.これは通常の転倒数と同様にBITを用いることで O(N\log N) で解けます.

期待値を0 or 1の確率変数で書いて線形性から分解するとわかりやすくなるやつ
とりあえずソートすると見通しがよくなるやつ

#include <bits/stdc++.h>    
using namespace std;    
using ll = long long;    
using PII = pair<ll, ll>;    
#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)    
#define REP(i, n) FOR(i, 0, n)    
#define ALL(x) x.begin(), x.end()    
template<typename T> void chmin(T &a, const T &b) { a = min(a, b); }    
template<typename T> void chmax(T &a, const T &b) { a = max(a, b); }    
struct FastIO {FastIO() { cin.tie(0); ios::sync_with_stdio(0); }}fastiofastio;    
#ifdef DEBUG_     
#include "../program_contest_library/memo/dump.hpp"    
#else    
#define dump(...)    
#endif    
const ll INF = 1LL<<60;  
  
template<ll MOD>  
struct modint {  
    ll x;  
    modint(): x(0) {}  
    modint(ll y) : x(y>=0 ? y%MOD : y%MOD+MOD) {}  
    static constexpr ll mod() { return MOD; }  
    // e乗  
    modint pow(ll e) {  
        ll a = 1, p = x;  
        while(e > 0) {  
            if(e%2 == 0) {p = (p*p) % MOD; e /= 2;}  
            else {a = (a*p) % MOD; e--;}  
        }  
        return modint(a);  
    }  
    modint inv() const {  
        ll a=x, b=MOD, u=1, y=1, v=0, z=0;  
        while(a) {  
            ll q = b/a;  
            swap(z -= q*u, u);  
            swap(y -= q*v, v);  
            swap(b -= q*a, a);  
        }  
        return z;  
    }  
    // Comparators  
    bool operator <(modint b) { return x < b.x; }  
    bool operator >(modint b) { return x > b.x; }  
    bool operator<=(modint b) { return x <= b.x; }  
    bool operator>=(modint b) { return x >= b.x; }  
    bool operator!=(modint b) { return x != b.x; }  
    bool operator==(modint b) { return x == b.x; }  
    // Basic Operations  
    modint operator+(modint r) const { return modint(*this) += r; }  
    modint operator-(modint r) const { return modint(*this) -= r; }  
    modint operator*(modint r) const { return modint(*this) *= r; }  
    modint operator/(modint r) const { return modint(*this) /= r; }  
    modint &operator+=(modint r) {  
        if((x += r.x) >= MOD) x -= MOD;  
        return *this;  
    }  
    modint &operator-=(modint r) {  
        if((x -= r.x) < 0) x += MOD;  
        return *this;  
    }  
    modint &operator*=(modint r) {  
    #if !defined(_WIN32) || defined(_WIN64)  
        x = x * r.x % MOD; return *this;  
    #endif  
        unsigned long long y = x * r.x;  
        unsigned xh = (unsigned) (y >> 32), xl = (unsigned) y, d, m;  
        asm(  
            "divl %4; \n\t"  
            : "=a" (d), "=d" (m)  
            : "d" (xh), "a" (xl), "r" (MOD)  
        );  
        x = m;  
        return *this;  
    }  
    modint &operator/=(modint r) { return *this *= r.inv(); }  
    // increment, decrement  
    modint operator++() { x++; return *this; }  
    modint operator++(signed) { modint t = *this; x++; return t; }  
    modint operator--() { x--; return *this; }  
    modint operator--(signed) { modint t = *this; x--; return t; }  
    // 平方剰余のうち一つを返す なければ-1  
    friend modint sqrt(modint a) {  
        if(a == 0) return 0;  
        ll q = MOD-1, s = 0;  
        while((q&1)==0) q>>=1, s++;  
        modint z=2;  
        while(1) {  
            if(z.pow((MOD-1)/2) == MOD-1) break;  
            z++;  
        }  
        modint c = z.pow(q), r = a.pow((q+1)/2), t = a.pow(q);  
        ll m = s;  
        while(t.x>1) {  
            modint tp=t;  
            ll k=-1;  
            FOR(i, 1, m) {  
                tp *= tp;  
                if(tp == 1) { k=i; break; }  
            }  
            if(k==-1) return -1;  
            modint cp=c;  
            REP(i, m-k-1) cp *= cp;  
            c = cp*cp, t = c*t, r = cp*r, m = k;  
        }  
        return r.x;  
    }  
  
    template<class T>  
    friend modint operator*(T l, modint r) { return modint(l) *= r; }  
    template<class T>  
    friend modint operator+(T l, modint r) { return modint(l) += r; }  
    template<class T>  
    friend modint operator-(T l, modint r) { return modint(l) -= r; }  
    template<class T>  
    friend modint operator/(T l, modint r) { return modint(l) /= r; }  
    template<class T>  
    friend bool operator==(T l, modint r) { return modint(l) == r; }  
    template<class T>  
    friend bool operator!=(T l, modint r) { return modint(l) != r; }  
    // Input/Output  
    friend ostream &operator<<(ostream& os, modint a) { return os << a.x; }  
    friend istream &operator>>(istream& is, modint &a) {   
        is >> a.x;  
        a.x = ((a.x%MOD)+MOD)%MOD;  
        return is;  
    }  
    friend string to_frac(modint v) {  
        static map<ll, PII> mp;  
        if(mp.empty()) {  
            mp[0] = mp[MOD] = {0, 1};  
            FOR(i, 2, 1001) FOR(j, 1, i) if(__gcd(i, j) == 1) {  
                mp[(modint(i) / j).x] = {i, j};  
            }  
        }  
        auto itr = mp.lower_bound(v.x);  
        if(itr != mp.begin() && v.x - prev(itr)->first < itr->first - v.x) --itr;  
        string ret = to_string(itr->second.first + itr->second.second * ((int)v.x - itr->first));  
        if(itr->second.second > 1) {  
            ret += '/';  
            ret += to_string(itr->second.second);  
        }  
        return ret;  
    }  
};  
using mint = modint<1000000007>;   
  
template <typename T>  
struct BIT {  
    int n;  
    vector<T> bit;  
    BIT(int n_ = 1e5) { init(n_); }  
    void init(int sz) {   
        n=1; while(n < sz) n*=2;  
        bit.assign(n+1, 0);   
    }  
    void update(int i, T w) {  
        for(int x=i+1; x<(int)bit.size(); x += x&-x) bit[x] += w;  
    }  
    // [0,i]  
    T query(int i) {  
        T ret = 0;  
        for(int x=i+1; x>0; x -= x&-x) ret += bit[x];  
        return ret;  
    }  
    // 合計がw以上の最小の位置  
    int lower_bound(T w) {  
        int x = 0;  
        for(int k=n; k>0; k>>=1) {  
            if(x+k <= n && bit[x+k] < w) {  
                w -= bit[x+k];  
                x += k;  
            }  
        }  
        return x;  
    }  
};  
  
signed main() {  
    ll n;  
    cin >> n;  
    vector<ll> a(n), b(n);  
    REP(i, n) {  
        cin >> a[i] >> b[i];  
        a[i]--, b[i]--;  
        if(a[i] > b[i]) swap(a[i], b[i]);  
    }  
    vector<ll> ord(n);  
    iota(ALL(ord), 0);  
    sort(ALL(ord), [&](ll l, ll r){  
        return PII(a[l], b[l]) < PII(a[r], b[r]);  
    });  
  
    mint ret = 0;  
    BIT<ll> bit(2*n);  
    for(auto i: ord) {  
        ret += bit.query(2*n-1) - bit.query(a[i]);  
        ret += bit.query(2*n-1) - bit.query(b[i]);  
        bit.update(a[i], 1);  
        bit.update(b[i], 1);  
    }  
    ret /= 4;  
    ret *= mint(2).pow(n);  
    cout << ret << endl;  
  
    return 0;  
}