ferinの競プロ帳

競プロについてのメモ

2015 ICL, Finals, Div. 1 J. Ceizenpok’s formula (nCk mod m の求め方)

問題ページ
Problem - J - Codeforces

概要

nCk mod m を求める。 1 <= n <= 10^18, 0 <= k <= n, 2 <= m <= 1000000

考えたこと

JAG夏合宿で見た問題
パット見nCk mod m求めるいつものやつだ!と思って制約を見たら真顔になった
n, kが大きすぎるのでn, kまでの階乗を計算しておくことができないしmが素数とは限らないので逆元が求まらないときもあるしわからなくて終了した
解説でlucasの定理とか中国剰余定理とか聞いたけどそのときはチンプンカンプンだった

{ }

中国剰余定理

2元の連立合同式について考える。
\begin{align} x \equiv a_1\ ({\rm mod}\ m_1) \cdots (1) \\ x \equiv a_2\ ({\rm mod}\ m_2) \cdots (2) \end{align} { m_1, m_2 }が互いに素であるとき { 0 \leq x \leq m_1m_2} の範囲に条件を満たす{ x }がただ一つ存在する。このとき、解xは拡張ユークリッドの互除法を利用して求めることができる。式(1)より
\begin{align} x = a_1 + zm_1 \end{align} となる。
\begin{align} x = a_2 = a_1 + zm_1\ ({\rm mod}\ m_2) \end{align} したがって、上式を満たすような zを見つければよい。 \begin{align} z = (a_2-a_1)k_1 (ただし k_1 \equiv m_1^{-1} {\rm mod}\ m_2) \end{align} と置いたとき上式が成り立つことを示す。 \begin{align} x = a_1 + (a_2-a_1)k_1m_1 \\ k_1m_1 \equiv 1 ({\rm mod}\ m_2) より x \equiv a_2 ({\rm mod}\ m_2) \end{align}

{m_1,m_2}は互いに素なことから逆元は存在する。逆元は拡張ユークリッドの式より求めることができるから{k_1}は求められる。したがって以下の式 \begin{align} x = a_1 + (a_2-a_1)k_1m_1 \end{align} より連立合同式の解{x}を求めることができる。

n個の合同式から成る連立合同式についても同様に解{x}を求めることができる。 \begin{align} x = a_i\ ({\rm mod}\ m_i) (1 \leq i \leq n) \end{align} まず{i=1,2}である2式について条件を満たす解を上の方法を用いて求める。すると、 \begin{align} x = a_{1,2}\ ({\rm mod}\ m_1m_2) \end{align} が求まる。次にこの式と{i=3}の式についての連立合同式の解を求める。これを繰り返していくことで最終的に条件を満たす解{x} \begin{align} x = a'\ ({\rm mod}\ m_1m_2\cdots m_n) \end{align} を求めることができる。

modの素因数分解

{C(n, k)\ ({\rm mod} m)}{m}素因数分解する。 \begin{align} m\ =\ m_1m_2\cdots m_r\ =\ p_1^{q_1}p_2^{q_2}\cdots p_r^{q_r} \end{align} このとき{C(n, k)\ ({\rm mod}\ m_i) (1 \leq i \leq r)}を求めることができたとする。するとr個の連立合同式が導出でき、中国剰余定理を用いることにより \begin{align} nCk\ ({\rm mod}\ m_1m_2\cdots m_r = m) \end{align} を求めることができる。

lucasの定理

では、{C(n, k)\ ({\rm mod}\ m_i) (1 \leq i \leq r)}はどのようにして求めればいいのか。これにはlucasの定理を用いる。素数{p}と非負整数{m,n}について
\begin{align} C(n,k) = \prod_{i=0}^{l} C(n_i, k_i)\ ({\rm mod}\ p) \end{align}
が成り立つ。ここで{n_i,k_i}
\begin{align} n = n_{0} + n_{1}p + \cdots + n_{l-1}p^{l-1} \\ k = k_{0} + k_{1}p + \cdots + k_{l-1}p^{l-1} \end{align}
を表す。したがって、modが素数のときは{O({\rm log}n)}で二項係数{C(n,k)}を求めることができる。

lucasの定理の拡張

modが素数の累乗{p^q}のときに二項係数{C(n,k)}を求める。
{r=n-k}とおく。{n_i,k_i,r_i}を上と同様にp進数に展開したときのi桁目と定義する。また、{N_j,K_j,R_j}{j}桁目から{j+q-1}桁目までの部分列と定義する。 \begin{align} N_j = n_{j} + n_{j+1}p + \cdots + n_{j+q-1}p^{q-1} \\ N_j = [n/p^{j}]\ ({\rm mod}\ p^{q}) \end{align} また、{e_{j}}を \begin{align} e_{j} = ([n/p^{j+1}] + [n/p^{j+2} + \cdots])-([r/p^{j+1}] + [r/p^{j+2} + \cdots]) \\ -([k/p^{j+1}] + [k/p^{j+2} + \cdots]) \end{align} と定義する。さらに{(k!)_{p}}をk以下でpの倍数でないものの積と定義する。これらを用いると二項係数{C(n,k)}は \begin{align} \frac{C(n,k)}{p^{e_0}} = (\pm 1)^{e_{q-1}} (\frac{(N_{0}!)_p}{(M_{0}!)_p(R_{0}!)_p})(\frac{(N_{1}!)_p}{(M_{1}!)_p(R_{1}!)_p}) \cdots (\frac{(N_{l}!)_p}{(M_{l}!)_p(R_{l}!)_p})\ ({\rm mod}\ p^{q}) \end{align} と書ける。{\pm 1}の部分は{p=2, q \geq 3}のときには1、それ以外のときには-1となる。

{e_{j}}{O({\rm log}n + {\rm log}r + {\rm log}k)}で求めることができる。{{k!}_p}はp進数q桁の数字であることを考えると最大でも{p^q}で抑えられる。{O(p^{q})}でとりうる{(k!)_{p}}を列挙する。また、{(k!)_{p}}にはpの倍数が含まれないことからmodの{p^q}とは互いに素である。したがって逆元が存在する。拡張ユークリッドの互除法を用いて逆元を計算することができる。

よってmodが素数の累乗のときの二項係数は計算することができる。

まとめ

まずmodを因数分解し、得られた因数をmodとする二項係数の値をlucasの定理の拡張を用いて求める。その後中国剰余定理で最終的な値を求める。

参考文献

nCr mod mの求め方 - uwicoder - アットウィキ
Coding-Templates/BinCoeff.pdf at master · rishirajsinghjhelumi/Coding-Templates · GitHub
競技プログラミングにおける数学的問題まとめ - はまやんはまやんはまやん
Lucasの定理とその証明 | 高校数学の美しい物語
中国剰余定理の証明と例題(二元の場合) | 高校数学の美しい物語

#define __USE_MINGW_ANSI_STDIO 0
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
#define int ll
typedef vector<int> VI;
typedef vector<VI> VVI;
typedef pair<int, int> PII;

#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()
#define IN(a, b, x) (a<=x&&x<b)
#define PB push_back

const ll LLINF = (1LL<<60);
const int INF = (1LL<<30);
const int MOD = 1000000007;

template <typename T> T &chmin(T &a, const T &b) { return a = min(a, b); }
template <typename T> T &chmax(T &a, const T &b) { return a = max(a, b); }
template<class S,class T>
ostream &operator <<(ostream& out,const pair<S,T>& a){
  out<<'('<<a.first<<','<<a.second<<')';
  return out;
}

int dx[] = {0, 1, 0, -1}, dy[] = {1, 0, -1, 0};

// ax + by = gcd(a, b) となる {x, y, gcd(a, b)} を返す
// O(log(min(a, b)))
ll extgcd(ll a, ll b, ll &x, ll &y) {
  ll g = a; x = 1, y = 0;
  if(b != 0) g = extgcd(b, a%b, y, x), y -= (a/b) * x;
  return g;
}

// a^-1 mod n を返す 存在しなければ-1
// O(log(n))
ll inv(ll a, ll n) {
  ll s, t;
  extgcd(a, n, s, t);
  return (n+s) % n;
}

// 二分累乗法 x^e % mod O(log(e))
ll binpow(ll x, ll e, ll mod) {
  ll a = 1, p = x;
  while(e > 0) {
    if(e%2 == 0) {p = (p*p) % mod; e /= 2;}
    else {a = (a*p) % mod; e--;}
  }
  return a % mod;
}

// x = a1 mod m1, x2 = a2 mod m2 を解く オーバーフローには注意
// O(log(min(m1, m2)))
pair<ll, ll> crt(ll a1, ll a2, ll m1, ll m2) {
  auto normal = [](ll x, ll m) { return x>=-x ? x%m : m-(-x)%m; };
  auto modmul = [&normal](ll a, ll b, ll m) { return normal(a, m)*normal(b, m)%m; };
  ll k1, k2;
  ll g = extgcd(m1, m2, k1, k2);
  if(normal(a1, g) != normal(a2, g)) return {-1, -1};
  ll l = m1 / g * m2;
  ll x = a1 + modmul(modmul((a2-a1)/g, k1, l), m1, l);
  return {x, l};
}

pair<ll, ll> crt(vector<ll> a, vector<ll> m) {
  ll mod = 1, ans = 0;
  int n = a.size();
  REP(i, n) {
    tie(ans, mod) = crt(ans, a[i], mod, m[i]);
    if(ans == -1) return {-1, -1};
  }
  return {ans, mod};
}

ll fact[1000010], ifact[1000010];
void makeFac(ll p, ll q) {
  ll pr = 1;
  REP(i, q) pr *= p;
  fact[0] = ifact[0] = 1;
  FOR(i, 1, pr+1) {
    if(i%p == 0) {
      fact[i] = fact[i-1];
    } else {
      fact[i] = fact[i-1] * i % pr;
    }
    ifact[i] = inv(fact[i], pr);
  }
}

ll C(ll n, ll r, ll p, ll q) {
  if(n < 0 || r < 0 || r > n) return 0;
  // pr = p^q
  int pr = 1;
  REP(i, q) pr *= p;

  ll z = n-r;
  int e0 = 0;
  for(ll u=n/p;u;u/=p) e0 += u;
  for(ll u=r/p;u;u/=p) e0 -= u;
  for(ll u=z/p;u;u/=p) e0 -= u;
  int em = 0;
  for(ll u=n/pr;u;u/=p) em += u;
  for(ll u=r/pr;u;u/=p) em -= u;
  for(ll u=z/pr;u;u/=p) em -= u;

  ll ret = 1;
  while(n > 0) {
    ret = ret*fact[n%pr]%pr*ifact[r%pr]%pr*ifact[z%pr]%pr;
    n /= p; r /= p; z /= p;
  }
  (ret *= binpow(p, e0, pr)) %= pr;
  if(!(p==2 && q >= 3) && em%2) ret = (pr-ret) % pr;
  return ret;
}

ll func(ll n, ll r, ll mod) {
  ll x = mod;
  vector<ll> a, m;
  FOR(i, 2, mod+1) if(x%i == 0) {
    ll cnt=0, pr=1;
    while(x%i==0) x/=i, cnt++, pr*=i;
    makeFac(i, cnt);
    a.PB(C(n, r, i, cnt));
    m.PB(pr);
  }

  return crt(a, m).first;
}

signed main(void)
{
  ll n, m, mo;
  cin >> n >> m >> mo;
  cout << func(n, m, mo) << endl;

  return 0;
}