ARC055 C - ABCAC
問題ページ
C - ABCAC
公式解説のaとcを高速に求める方法について
解法
aとcをそれぞれ求めたあと a>0 && c>0 && a+c>=|y| を満たしていればABCACが構成できる文字列ABCの組み合わせは a+c-|y|+1 通りある。
SAとlcp配列とRMQ
SAとlcp配列を構築する。rank[sa[i]] = i とする。
文字列Sのi,jから始まる接尾辞の先頭共通文字数は区間 [rank[i], rank[j]) のうち最小のlcp配列の値なのでセグメントツリーなどを使えばO(logN)で求められる。
#include <bits/stdc++.h> using namespace std; typedef long long ll; #define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i) #define REP(i, n) FOR(i, 0, n) #define ALL(x) x.begin(), x.end() // suffix array int n, k1; int tmp1[200010], ran1[200010]; bool compare_sa1(int i, int j) { if(ran1[i] != ran1[j]) return ran1[i] < ran1[j]; else { int ri = i+k1<=n ? ran1[i+k1] : -1; int rj = j+k1<=n ? ran1[j+k1] : -1; return ri < rj; } } // O(nlog^2n) void construct_sa1(string s, int *sa) { n = s.size(); REP(i, n+1) sa[i] = i, ran1[i] = i<n ? s[i] : -1; for(k1 = 1; k1 <= n; k1*=2) { sort(sa, sa+n+1, compare_sa1); tmp1[sa[0]] = 0; FOR(i, 1, n+1) { tmp1[sa[i]] = tmp1[sa[i-1]] + (compare_sa1(sa[i-1], sa[i]) ? 1 : 0); } REP(i, n+1) { ran1[i] = tmp1[i]; } } } // O(n) void construct_lcp1(string s, int *sa, int *lcp) { n = s.size(); REP(i, n+1) ran1[sa[i]] = i; int h = 0; lcp[0] = 0; REP(i, n) { int j = sa[ran1[i] - 1]; if(h > 0) h--; for(; j+h<n && i+h<n; ++h) { if(s[j+h] != s[i+h]) break; } lcp[ran1[i]-1] = h; } } int k2; int tmp2[200010], ran2[200010]; bool compare_sa2(int i, int j) { if(ran2[i] != ran2[j]) return ran2[i] < ran2[j]; else { int ri = i+k2<=n ? ran2[i+k2] : -1; int rj = j+k2<=n ? ran2[j+k2] : -1; return ri < rj; } } // O(nlog^2n) void construct_sa2(string s, int *sa) { n = s.size(); REP(i, n+1) sa[i] = i, ran2[i] = i<n ? s[i] : -1; for(k2 = 1; k2 <= n; k2*=2) { sort(sa, sa+n+1, compare_sa2); tmp2[sa[0]] = 0; FOR(i, 1, n+1) { tmp2[sa[i]] = tmp2[sa[i-1]] + (compare_sa2(sa[i-1], sa[i]) ? 1 : 0); } REP(i, n+1) { ran2[i] = tmp2[i]; } } } // O(n) void construct_lcp2(string s, int *sa, int *lcp) { n = s.size(); REP(i, n+1) ran2[sa[i]] = i; int h = 0; lcp[0] = 0; REP(i, n) { int j = sa[ran2[i] - 1]; if(h > 0) h--; for(; j+h<n && i+h<n; ++h) { if(s[j+h] != s[i+h]) break; } lcp[ran2[i]-1] = h; } } // 文字列S,T 接尾辞配列sa O(|T|log|S|) で文字列検索をする bool contain(string s, string t, int *sa) { int a = 0, b = s.length(); while(b-a > 1) { int c = (a+b)/2; if(s.compare(sa[c], t.length(), t) < 0) a = c; else b = c; } return s.compare(sa[b], t.length(), t) == 0; } template<class T> class segmentTree { public: int size_; vector<T> dat; T init__ = INT_MAX; // 単位元 segmentTree() {} segmentTree(int n) { for(size_ = 1; size_ < n; size_ *= 2); dat.assign(2*size_-1, init__); } T calc(T d1, T d2) { return min(d1, d2); // minnimum // return max(d1, d2); // maximum // return d1+d2; // sum } T query(int a, int b, int k, int l, int r) { if(r <= a || b <= l) return init__; if(a <= l && r <= b) return dat[k]; return calc(query(a, b, 2*k+1, l, (l+r)/2), query(a, b, 2*k+2, (l+r)/2, r)); } T query(int a, int b) {return query(a, b, 0, 0, size_);} void update(int k, T a) { k += size_ - 1; dat[k] = a; // max or min // dat[k] += a; // sum while(k > 0) { k = (k-1) / 2; dat[k] = calc(dat[k*2+1], dat[k*2+2]); } } }; int sa1[200010], lcp1[200010]; int sa2[200010], lcp2[200010]; signed main(void) { string s; cin >> s; segmentTree<int> seg1(s.size()); construct_sa1(s, sa1); construct_lcp1(s, sa1, lcp1); REP(i, n) seg1.update(i, lcp1[i]); string s2 = s; reverse(ALL(s2)); segmentTree<int> seg2(s2.size()); construct_sa2(s2, sa2); construct_lcp2(s2, sa2, lcp2); REP(i, n) seg2.update(i, lcp2[i]); ll ret = 0; FOR(i, 1, n) { if(i <= n-i) continue; int l = ran1[0], r = ran1[i]; if(l > r) swap(l, r); int a = min(n-i-1, (ll)seg1.query(l, r)); l = ran2[n-i], r = ran2[0]; if(l > r) swap(l, r); int c = min(n-i-1, (ll)seg2.query(l, r)); if(a>0 && c>0 && a+c >= n-i) ret += min(n-i-1, a+c-(n-i)+1); } cout << ret << endl; return 0; }
z-algorithm
文字列SとS[i,|S|-1]の共通先頭文字数をO(|S|)で求めるという問題そのまんまなアルゴリズム。 snuke.hatenablog.com
#include <bits/stdc++.h> using namespace std; typedef long long ll; #define int ll typedef vector<int> VI; #define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i) #define REP(i, n) FOR(i, 0, n) #define ALL(x) x.begin(), x.end() // z-algotirhm O(|S|) VI Zalgo(string s) { VI v(s.size()); v[0] = s.size(); int i = 1, j = 0; while (i < s.size()) { while (i+j < s.size() && s[j] == s[i+j]) ++j; v[i] = j; if (j == 0) { ++i; continue;} int k = 1; while (i+k < s.size() && k+v[k] < j) v[i+k] = v[k], ++k; i += k; j -= k; } return v; } signed main(void) { string s; cin >> s; int n = s.size(); VI v1 = Zalgo(s); string s2 = s; reverse(ALL(s2)); VI v2 = Zalgo(s2); ll ret = 0; FOR(i, 1, n) { if(i <= n-i) continue; int a = min(n-i-1, v1[i]); int c = min(n-i-1, v2[n-i]); if(a>0 && c>0 && a+c >= n-i) ret += min(n-i-1, a+c-(n-i)+1); } cout << ret << endl; return 0; }
ローリングハッシュ
ローリングハッシュを使うとある2つの文字列が一致しているかどうかの判定はO(1)でできる。文字列xとyの先頭d文字が共通かどうかという判定問題に置き換えdを最大化すればaが求まる。この判定はO(1)ででき、単調性があるので二分探索が可能なのでO(log|S|)でaとcをそれぞれ求めることができる。合計でO(|S|log|S|)で解ける。
#include <bits/stdc++.h> using namespace std; typedef long long ll; #define int ll #define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i) #define REP(i, n) FOR(i, 0, n) #define ALL(x) x.begin(), x.end() // rolling-hash class rollingHash { public: static const int MAX_N = 200010; // MOD と 基数 ll mo[2] = {1000000007, 1000000009}; ll base[2] = {1009, 1007}; ll hash[2][MAX_N], power[2][MAX_N]; rollingHash() {} rollingHash(string s) { init(s); } // O(|S|) void init(string s) { REP(i, 2) { power[i][0] = 1; FOR(j, 1, MAX_N) power[i][j] = power[i][j-1]*base[i]%mo[i]; } REP(i, 2) REP(j, s.size()) { hash[i][j+1] = (hash[i][j]+power[i][j]*(s[j]-'a'))%mo[i]; } } // [l1, r1) と [l2, r2) が一致するかの判定 (l1 < l2) bool equal(int l1, int r1, int l2, int r2) { REP(i, 2) { ll a = (((hash[i][r1]-hash[i][l1])%mo[i])+mo[i])%mo[i]; ll b = (((hash[i][r2]-hash[i][l2])%mo[i])+mo[i])%mo[i]; if(a*power[i][l2-l1]%mo[i] == b) return true; } return false; } }; signed main(void) { string s, s2; cin >> s; int n = s.size(); s2 = s; reverse(ALL(s2)); rollingHash hs1(s), hs2(s2); ll ret = 0; FOR(i, 1, n) { if(i <= n-i) continue; int lb = 0, ub = n-i; while(ub-lb > 1) { int mid = (lb+ub)/2; if(hs1.equal(0, mid, i, i+mid)) { lb = mid; } else { ub = mid; } } int a = lb; lb = 0, ub = n-i; while(ub-lb > 1) { int mid = (lb+ub)/2; if(hs1.equal(i-mid, i, n-mid, n)) { lb = mid; } else { ub = mid; } } int c = lb; if(a>0 && c>0 && a+c >= n-i) ret += min(n-i-1, a+c-(n-i)+1); } cout << ret << endl; return 0; }