ferinの競プロ帳

競プロについてのメモ

yukicoder No.776 A Simple RMQ Problem

問題ページ

解法

この記事(Maximum Subarray Sum in a given Range - GeeksforGeeks)のセグメント木を使って解いた。このセグ木では区間[l,r)の連続した部分列の区間和のmaxを求めることができる。セグ木の各頂点に「区間和」「maximum prefix sum」「maximum prefix sum」「部分列の区間和のmax」の4つの情報を持たせる。これらの情報sum,psum,ssum,maxの4つをもたせた構造体nodeを各頂点として扱う。
このセグ木を使ってmaxクエリに答える。まずr1<l1であればr1=l1、r2<l2であればl2=r2とできるのでr1>=l1,r2>=l2と考える。l2 < r1であれば答えは(区間[l1,l2]のssum) + (区間(l2,r1)のsum) + (区間[r1,r2]のpsum)となる。l2 >= r1であれば答えはmax(区間[l1,r1)のssum+区間[r1,l2]のsum+区間(l2,r2]のpsum、区間[l1,r1)のssum+[r1,l2]のsum、区間[r1,l2]のssum+区間(l2,r2]のpsum、区間[r1,l2]のmax)となる。

#include <bits/stdc++.h>
 
using namespace std;
using ll = long long;
// #define int ll
using PII = pair<ll, ll>;
 
#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define ALL(x) x.begin(), x.end()
 
template<typename T> T &chmin(T &a, const T &b) { return a = min(a, b); }
template<typename T> T &chmax(T &a, const T &b) { return a = max(a, b); }
template<typename T> bool IN(T a, T b, T x) { return a<=x&&x<b; }
template<typename T> T ceil(T a, T b) { return a/b + !!(a%b); }
 
template<typename T> vector<T> make_v(size_t a) { return vector<T>(a); }
template<typename T,typename... Ts>
auto make_v(size_t a,Ts... ts) { 
  return vector<decltype(make_v<T>(ts...))>(a,make_v<T>(ts...));
}
template<typename T,typename V> typename enable_if<is_class<T>::value==0>::type
fill_v(T &t, const V &v) { t=v; }
template<typename T,typename V> typename enable_if<is_class<T>::value!=0>::type
fill_v(T &t, const V &v ) { for(auto &e:t) fill_v(e,v); }
 
template<class S,class T>
ostream &operator <<(ostream& out,const pair<S,T>& a){
  out<<'('<<a.first<<','<<a.second<<')'; return out;
}
template<typename T>
istream& operator >> (istream& is, vector<T>& vec){
  for(T& x: vec) {is >> x;} return is;
}
template<class T>
ostream &operator <<(ostream& out,const vector<T>& a){
  out<<'['; for(T i: a) {out<<i<<',';} out<<']'; return out;
}
 
int dx[] = {0, 1, 0, -1}, dy[] = {1, 0, -1, 0}; // DRUL
const int INF = 1<<30;
const ll LLINF = 1LL<<60;
const int MOD = 1000000007;

/**
* @brief セグメント木
* @details 遅延評価をしない普通のセグメント木\n
* 点更新区間min d=INF, f=min(a,b), g=b\n
* 点更新区間max d=-INF, f=max(a,b), g=b\n
* 点加算区間和  d=0, f=a+b, g+=b
*/
template <typename T>
class segtree {
public:
  int n;
  vector<T> dat;
  T d;
  function<T(T,T)> f, g;

  segtree(int n_, function<T(T,T)> f_, function<T(T,T)> g_, T d_) 
    : f(f_), g(g_), d(d_) {
    n = 1;
    while(n < n_) n *= 2;
    dat.assign(n*2, d); 
  }
  // [a, b)
  T query(int a, int b, int k, int l, int r) {
    if(r <= a || b <= l) return d;
    if(a <= l && r <= b) return dat[k];
    return f(query(a, b, 2*k+1, l, (l+r)/2),
             query(a, b, 2*k+2, (l+r)/2, r));
  }
  T query(int a, int b) {return query(a, b, 0, 0, n);}
  void update(int i, T v) {
    i += n-1;
    dat[i] = g(dat[i], v);
    while(i > 0) {
      i = (i-1)/2;
      dat[i] = f(dat[i*2+1], dat[i*2+2]);
    }
  }
};
signed main(void)
{
  cin.tie(0);
  ios::sync_with_stdio(false);

  ll n, q;
  cin >> n >> q;
  vector<ll> a(n);
  REP(i, n) cin >> a[i];

  struct node {
    ll sum, psum, ssum, max;
    node() {}
    node(ll a, ll b, ll c, ll d) : sum(a), psum(b), ssum(c), max(d) {}
  };
  auto f = [](node l, node r) {
    node ret;
    ret.sum = l.sum + r.sum;
    ret.psum = max(l.psum, l.sum + r.psum);
    ret.ssum = max(r.ssum, l.ssum + r.sum);
    ret.max = max({l.max, r.max, l.ssum + r.psum});
    return ret;
  };
  auto g = [](node l, node r) {
    return r;
  };
  segtree<node> seg(n, f, g, node(0, -LLINF, -LLINF, -LLINF));

  REP(i, n) {
    node tmp(a[i], a[i], a[i], a[i]);
    seg.update(i, tmp);
  }
  REP(i, q) {
    string s;
    cin >> s;
    if(s == "max") {
      ll l1, l2, r1, r2;
      cin >> l1 >> l2 >> r1 >> r2;
      l1--, l2--, r1--, r2--;
      if(r1 < l1) r1 = l1;
      if(r2 < l2) l2 = r2;
      if(l2 < r1) {
        cout << seg.query(l1, l2+1).ssum + seg.query(l2+1, r1).sum + seg.query(r1, r2+1).psum << endl;
      } else {
        node left = seg.query(l1, r1),
             mid = seg.query(r1, l2+1),
             right = seg.query(l2+1, r2+1);
        cout << max({left.ssum + mid.sum + right.psum,
                     left.ssum + mid.psum,
                     mid.ssum + right.psum,
                     mid.max}) << endl;
      }
    } else {
      ll idx, val;
      cin >> idx >> val; idx--;
      node tmp(val, val, val, val);
      seg.update(idx, tmp);
    }
  }

  return 0;
}