ferinの競プロ帳

競プロについてのメモ

ARC090 E - Avoiding Collision

問題ページ
E - Avoiding Collision

解法

 ある頂点uから頂点vへの最短経路が何通りあるか求めたい。ある辺(u,v)を最短経路に使う可能性が存在するかどうかは d[u] + cost = d[v] で判定できる。ある無向辺(u,v)が存在したとき最短経路にu→vとv→uの経路を両方使うことはありえない。また、負の辺が存在しないので最短経路に閉路が存在することはない。したがって最短経路に使う可能性がある辺の集合はDAGになる。このDAG上での経路は全て最短経路となる。DAG上で頂点uから頂点vへの経路はDPすることで求められるので最短経路が何通りあるか求めることができる。
 上の方法で頂点u→Tへの最短経路の種類数、頂点v→Tへの最短経路の種類数がそれぞれわかる。これをそれぞれdp1[u]、dp2[v]とする。dp1[T]^2が二人が出会う経路を含んだ最短路の選び方の組の総数になる。
 最短経路上ですれちがう可能性があるのは(最短経路長)/2の地点だけである。ある頂点uで出会う⇔d[u]*2 = d[t] でこの経路数は(dp1[u]*dp2[u])^2である。ある辺(u,v)で出会う⇔d[u]*2 < d[t] && d[v]*2 > d[t] && d[u] + 辺のcost = d[v] でこの経路数は(dp1[u]*dp2[v])^2となる。
 したがってdp[T]^2 から出会う場合の経路数を引いた数が答えである。

学び

  • あるグラフで最短経路数を求める方法
  • すれ違う可能性があるのは(最短経路長)/2の地点だけ
  • (求めたいパターン) = (全パターン) - (だめなパターン)
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
#define int ll
typedef vector<int> VI;
typedef vector<VI> VVI;
typedef pair<int, int> PII;

#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()
#define IN(a, b, x) (a<=x&&x<b)
#define PB push_back

const ll LLINF = (1LL<<60);
const int INF = (1LL<<30);
const int MOD = 1000000007;

template <typename T> T &chmin(T &a, const T &b) { return a = min(a, b); }
template <typename T> T &chmax(T &a, const T &b) { return a = max(a, b); }
template<typename T> T ceil(T a, T b) { return a/b + !!(a%b); }
template<class S,class T>
ostream &operator <<(ostream& out,const pair<S,T>& a){
  out<<'('<<a.first<<','<<a.second<<')';
  return out;
}

int dx[] = {0, 1, 0, -1}, dy[] = {1, 0, -1, 0};

template<unsigned MOD>
class ModInt {
public:
  unsigned x;
  ModInt(): x(0) { }
  ModInt(signed y) : x(y >= 0 ? y % MOD : MOD - (-y) % MOD) {}
  unsigned get() const { return x; }

  // 逆数
  ModInt inv() const {
    ll a = 1, p = x, e = MOD-2;
    while(e > 0) {
      if(e%2 == 0) {p = (p*p) % MOD; e /= 2;}
      else {a = (a*p) % MOD; e--;}
    }
    a %= MOD;
    return ModInt(a);
  }
  // e乗
  ModInt pow(ll e) {
    ll a = 1, p = x;
    while(e > 0) {
      if(e%2 == 0) {p = (p*p) % MOD; e /= 2;}
      else {a = (a*p) % MOD; e--;}
    }
    a %= MOD;
    return ModInt(a);
  }
  // 2のx乗
  ModInt pow2() {
    ll a = 1, p = 2, e = x;
    while(e > 0) {
      if(e%2 == 0) {p = (p*p) % MOD; e /= 2;}
      else {a = (a*p) % MOD; e--;}
    }
    a %= MOD;
    return ModInt(a);
  }

  // Comparators
  bool operator <(ModInt b) { return x < b.x; }
  bool operator >(ModInt b) { return x > b.x; }
  bool operator<=(ModInt b) { return x <= b.x; }
  bool operator>=(ModInt b) { return x >= b.x; }
  bool operator!=(ModInt b) { return x != b.x; }
  bool operator==(ModInt b) { return x == b.x; }

  // increment, decrement
  ModInt operator++() { x++; return *this; }
  ModInt operator--() { x--; return *this; }

  // Basic Operations
  ModInt &operator+=(ModInt that) {
    x = ((ll)x+that.x)%MOD;
    return *this;
  }
  ModInt &operator-=(ModInt that) {
    x = ((((ll)x-that.x)%MOD)+MOD)%MOD;
    return *this;
  }
  ModInt &operator*=(ModInt that) {
    x = (ll)x * that.x % MOD;
    return *this;
  }
  // O(log(mod))かかるので注意
  ModInt &operator/=(ModInt that) {
    x = (ll)x * that.inv() % MOD;
    return *this;
  }
  ModInt &operator%=(ModInt that) {
    x = (ll)x % that.x;
    return *this;
  }
  ModInt operator+(ModInt that)const{return ModInt(*this) += that;}
  ModInt operator-(ModInt that)const{return ModInt(*this) -= that;}
  ModInt operator*(ModInt that)const{return ModInt(*this) *= that;}
  ModInt operator/(ModInt that)const{return ModInt(*this) /= that;}
  ModInt operator%(ModInt that)const{return ModInt(*this) %= that;}
};
typedef ModInt<1000000007> mint;
// Input/Output
ostream &operator<<(ostream& os, mint a) { return os << a.x; }
istream &operator>>(istream& is, mint &a) { return is >> a.x; }

int a[200010], b[200010], c[200010], d[100010];
vector<PII> g[100010];
mint dp1[100010], dp2[100010];
signed main(void)
{
  int n, m, s, t;
  cin >> n >> m >> s >> t;
  s--, t--;
  REP(i, m) {
    cin >> a[i] >> b[i] >> c[i];
    a[i]--, b[i]--;
    g[a[i]].PB({b[i], c[i]});
    g[b[i]].PB({a[i], c[i]});
  }

  REP(i, n) d[i] = LLINF;
  d[s] = 0;
  priority_queue<PII, vector<PII>, greater<PII>> que;
  que.push({d[s], s});

  while(que.size()) {
    PII p = que.top(); que.pop();
    if(p.second == t) continue;
    if(p.first > d[p.second]) continue;
    for(PII e: g[p.second]) {
      if(d[e.first] > d[p.second] + e.second) {
        d[e.first] = d[p.second] + e.second;
        que.push({d[e.first], e.first});
      }
    }
  }

  vector<PII> vec;
  REP(i, n) vec.PB({d[i], i});
  sort(ALL(vec));

  dp1[s] = 1;
  REP(i, n) {
    for(PII e: g[vec[i].second]) {
      if(d[vec[i].second] + e.second == d[e.first]) {
        dp1[e.first] += dp1[vec[i].second];
      }
    }
  }
  dp2[t] = 1;
  for(int i=n-1; i>=0; --i) {
    for(PII e: g[vec[i].second]) {
      if(d[vec[i].second] == d[e.first] + e.second) {
        dp2[e.first] += dp2[vec[i].second];
      }
    }
  }

  mint ret = dp1[t]*dp1[t];
  REP(i, n) {
    if(d[i]*2 == d[t]) {
      ret -= dp1[i]*dp1[i]*dp2[i]*dp2[i];
    }
  }
  REP(i, m) {
    int u = a[i], v = b[i];
    if(d[u] > d[v]) swap(u, v);
    if(2*d[u] < d[t] && 2*d[v] > d[t] && d[u] + c[i] == d[v]) {
      ret -= dp1[u]*dp1[u]*dp2[v]*dp2[v];
    }
  }
  cout << ret << endl;

  return 0;
}