ferinの競プロ帳

競プロについてのメモ

ARC055 C - ABCAC

問題ページ
C - ABCAC

公式解説のaとcを高速に求める方法について

解法

aとcをそれぞれ求めたあと a>0 && c>0 && a+c>=|y| を満たしていればABCACが構成できる文字列ABCの組み合わせは a+c-|y|+1 通りある。

SAとlcp配列とRMQ

SAとlcp配列を構築する。rank[sa[i]] = i とする。
文字列Sのi,jから始まる接尾辞の先頭共通文字数は区間 [rank[i], rank[j]) のうち最小のlcp配列の値なのでセグメントツリーなどを使えばO(logN)で求められる。

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()

// suffix array
int n, k1;
int tmp1[200010], ran1[200010];

bool compare_sa1(int i, int j) {
  if(ran1[i] != ran1[j]) return ran1[i] < ran1[j];
  else {
    int ri = i+k1<=n ? ran1[i+k1] : -1;
    int rj = j+k1<=n ? ran1[j+k1] : -1;
    return ri < rj;
  }
}

// O(nlog^2n)
void construct_sa1(string s, int *sa) {
  n = s.size();
  REP(i, n+1) sa[i] = i, ran1[i] = i<n ? s[i] : -1;

  for(k1 = 1; k1 <= n; k1*=2) {
    sort(sa, sa+n+1, compare_sa1);

    tmp1[sa[0]] = 0;
    FOR(i, 1, n+1) {
      tmp1[sa[i]] = tmp1[sa[i-1]] + (compare_sa1(sa[i-1], sa[i]) ? 1 : 0);
    }
    REP(i, n+1) {
      ran1[i] = tmp1[i];
    }
  }
}

// O(n)
void construct_lcp1(string s, int *sa, int *lcp) {
  n = s.size();
  REP(i, n+1) ran1[sa[i]] = i;

  int h = 0;
  lcp[0] = 0;
  REP(i, n) {
    int j = sa[ran1[i] - 1];
    if(h > 0) h--;
    for(; j+h<n && i+h<n; ++h) {
      if(s[j+h] != s[i+h]) break;
    }

    lcp[ran1[i]-1] = h;
  }
}

int k2;
int tmp2[200010], ran2[200010];

bool compare_sa2(int i, int j) {
  if(ran2[i] != ran2[j]) return ran2[i] < ran2[j];
  else {
    int ri = i+k2<=n ? ran2[i+k2] : -1;
    int rj = j+k2<=n ? ran2[j+k2] : -1;
    return ri < rj;
  }
}

// O(nlog^2n)
void construct_sa2(string s, int *sa) {
  n = s.size();
  REP(i, n+1) sa[i] = i, ran2[i] = i<n ? s[i] : -1;

  for(k2 = 1; k2 <= n; k2*=2) {
    sort(sa, sa+n+1, compare_sa2);

    tmp2[sa[0]] = 0;
    FOR(i, 1, n+1) {
      tmp2[sa[i]] = tmp2[sa[i-1]] + (compare_sa2(sa[i-1], sa[i]) ? 1 : 0);
    }
    REP(i, n+1) {
      ran2[i] = tmp2[i];
    }
  }
}

// O(n)
void construct_lcp2(string s, int *sa, int *lcp) {
  n = s.size();
  REP(i, n+1) ran2[sa[i]] = i;

  int h = 0;
  lcp[0] = 0;
  REP(i, n) {
    int j = sa[ran2[i] - 1];
    if(h > 0) h--;
    for(; j+h<n && i+h<n; ++h) {
      if(s[j+h] != s[i+h]) break;
    }

    lcp[ran2[i]-1] = h;
  }
}

// 文字列S,T 接尾辞配列sa O(|T|log|S|) で文字列検索をする
bool contain(string s, string t, int *sa) {
  int a = 0, b = s.length();
  while(b-a > 1) {
    int c = (a+b)/2;
    if(s.compare(sa[c], t.length(), t) < 0) a = c;
    else b = c;
  }
  return s.compare(sa[b], t.length(), t) == 0;
}

template<class T>
class segmentTree {
public:
  int size_;
  vector<T> dat;
  T init__ = INT_MAX; // 単位元

  segmentTree() {}
  segmentTree(int n) {
    for(size_ = 1; size_ < n; size_ *= 2);
    dat.assign(2*size_-1, init__);
  }

  T calc(T d1, T d2) {
    return min(d1, d2);     // minnimum
    // return max(d1, d2);  // maximum
    // return d1+d2;        // sum
  }
  T query(int a, int b, int k, int l, int r) {
    if(r <= a || b <= l) return init__;
    if(a <= l && r <= b) return dat[k];
    return calc(query(a, b, 2*k+1, l, (l+r)/2),
                query(a, b, 2*k+2, (l+r)/2, r));
  }
  T query(int a, int b) {return query(a, b, 0, 0, size_);}
  void update(int k, T a) {
    k += size_ - 1;
    dat[k] = a;      // max or min
    // dat[k] += a;  // sum
    while(k > 0) {
      k = (k-1) / 2;
      dat[k] = calc(dat[k*2+1], dat[k*2+2]);
    }
  }
};

int sa1[200010], lcp1[200010];
int sa2[200010], lcp2[200010];
signed main(void)
{
  string s;
  cin >> s;

  segmentTree<int> seg1(s.size());
  construct_sa1(s, sa1); construct_lcp1(s, sa1, lcp1);
  REP(i, n) seg1.update(i, lcp1[i]);

  string s2 = s;
  reverse(ALL(s2));
  segmentTree<int> seg2(s2.size());
  construct_sa2(s2, sa2); construct_lcp2(s2, sa2, lcp2);
  REP(i, n) seg2.update(i, lcp2[i]);

  ll ret = 0;
  FOR(i, 1, n) {
    if(i <= n-i) continue;
    int l = ran1[0], r = ran1[i];
    if(l > r) swap(l, r);
    int a = min(n-i-1, (ll)seg1.query(l, r));

    l = ran2[n-i], r = ran2[0];
    if(l > r) swap(l, r);
    int c = min(n-i-1, (ll)seg2.query(l, r));

    if(a>0 && c>0 && a+c >= n-i) ret += min(n-i-1, a+c-(n-i)+1);
  }
  cout << ret << endl;

  return 0;
}

z-algorithm

文字列SとS[i,|S|-1]の共通先頭文字数をO(|S|)で求めるという問題そのまんまなアルゴリズムsnuke.hatenablog.com

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
#define int ll
typedef vector<int> VI;

#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()

// z-algotirhm O(|S|)
VI Zalgo(string s) {
  VI v(s.size());
  v[0] = s.size();
  int i = 1, j = 0;
  while (i < s.size()) {
    while (i+j < s.size() && s[j] == s[i+j]) ++j;
    v[i] = j;
    if (j == 0) { ++i; continue;}
    int k = 1;
    while (i+k < s.size() && k+v[k] < j) v[i+k] = v[k], ++k;
    i += k; j -= k;
  }
    return v;
}

signed main(void)
{
  string s;
  cin >> s;
  int n = s.size();

  VI v1 = Zalgo(s);

  string s2 = s;
  reverse(ALL(s2));
  VI v2 = Zalgo(s2);

  ll ret = 0;
  FOR(i, 1, n) {
    if(i <= n-i) continue;
    int a = min(n-i-1, v1[i]);
    int c = min(n-i-1, v2[n-i]);

    if(a>0 && c>0 && a+c >= n-i) ret += min(n-i-1, a+c-(n-i)+1);
  }
  cout << ret << endl;

  return 0;
}

ローリングハッシュ

ローリングハッシュを使うとある2つの文字列が一致しているかどうかの判定はO(1)でできる。文字列xとyの先頭d文字が共通かどうかという判定問題に置き換えdを最大化すればaが求まる。この判定はO(1)ででき、単調性があるので二分探索が可能なのでO(log|S|)でaとcをそれぞれ求めることができる。合計でO(|S|log|S|)で解ける。

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
#define int ll

#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()

// rolling-hash
class rollingHash {
public:
  static const int MAX_N = 200010;
  // MOD と 基数
  ll mo[2] = {1000000007, 1000000009};
  ll base[2] = {1009, 1007};
  ll hash[2][MAX_N], power[2][MAX_N];

  rollingHash() {}
  rollingHash(string s) { init(s); }

  // O(|S|)
  void init(string s) {
    REP(i, 2) {
      power[i][0] = 1;
      FOR(j, 1, MAX_N) power[i][j] = power[i][j-1]*base[i]%mo[i];
    }
    REP(i, 2) REP(j, s.size()) {
      hash[i][j+1] = (hash[i][j]+power[i][j]*(s[j]-'a'))%mo[i];
    }
  }

  // [l1, r1) と [l2, r2) が一致するかの判定 (l1 < l2)
  bool equal(int l1, int r1, int l2, int r2) {
    REP(i, 2) {
      ll a = (((hash[i][r1]-hash[i][l1])%mo[i])+mo[i])%mo[i];
      ll b = (((hash[i][r2]-hash[i][l2])%mo[i])+mo[i])%mo[i];
      if(a*power[i][l2-l1]%mo[i] == b) return true;
    }
    return false;
  }
};

signed main(void)
{
  string s, s2;
  cin >> s;

  int n = s.size();
  s2 = s;
  reverse(ALL(s2));
  rollingHash hs1(s), hs2(s2);

  ll ret = 0;
  FOR(i, 1, n) {
    if(i <= n-i) continue;

    int lb = 0, ub = n-i;
    while(ub-lb > 1) {
      int mid = (lb+ub)/2;
      if(hs1.equal(0, mid, i, i+mid)) {
        lb = mid;
      } else {
        ub = mid;
      }
    }
    int a = lb;

    lb = 0, ub = n-i;
    while(ub-lb > 1) {
      int mid = (lb+ub)/2;
      if(hs1.equal(i-mid, i, n-mid, n)) {
        lb = mid;
      } else {
        ub = mid;
      }
    }
    int c = lb;

    if(a>0 && c>0 && a+c >= n-i) ret += min(n-i-1, a+c-(n-i)+1);
  }
  cout << ret << endl;

  return 0;
}