ferinの競プロ帳

競プロについてのメモ

RUPC2018 day2 J Matrix

問題ページ

解法

ある矩形範囲(a,b,c,d)の最大値はその区間の max(A)*max(B), max(A)*min(B), min(A)*max(B), min(A)*min(B) のうちのどれかとなる。最小値についても同様にどれかになる。
最大値の個数については max(A)*max(B)が最大値であればmax(A)の個数*max(B)の個数, max(A)*min(B)が最大値であればmax(A)の個数*min(B)の個数,… の和によって求められる。
max(A)==min(A)のケースや答えが0になるケースに注意が必要。maxが0のときについて考えると、行列Cの要素のうち+になるようなものが存在してはいけないのでAとBが取る数の符号が異なる場合のみのはずである。したがって (min(A)>=0 かつ max(b)<=0) もしくは (max(A) <= 0 かつ min(B) >= 0) となっているはずで最大値の候補に上げた4個の積のうち0になるものが必ず存在する。つまりans=0のときも値については上に上げた4つのmaxを取ることで求められる。0になる個数を求めるのは矩形範囲全体から0でないものを引くとすればよい。minについても同様。

したがって区間加算区間max,min,countが可能な遅延セグメントツリーを使えばよい。この記事のように抽象化した遅延セグ木でどう実装したのか書く。(最大,最大の個数,最小,最小の個数)がほしいのでTはvector、Eにはその区間全体に加算する値がほしいのでintとする。マージする写像fでは最大,最小はmax,minを取る、個数は左右のノードを見てmaxとノードの値が一致すればそのノードの個数をプラスするとした。写像gでは区間加算なので最大と最小にだけプラスし個数には何もしない。写像hでは区間加算なので和を返す。単位元d0は(0,1,0,1)、d1は0とする。

まとめると遅延セグメントツリーを使いクエリに答えていく。下の実装はdefine int llしてるので注意。

ソースコード

#include <bits/stdc++.h>

using namespace std;
using ll = long long;
#define int ll
using VI = vector<int>;
using VVI = vector<VI>;

#define FOR(i, a, n) for (ll i = (ll)a; i < (ll)n; ++i)
#define REP(i, n) FOR(i, 0, n)
#define PB push_back

const ll LLINF = (1LL<<60);

// 遅延セグメントツリー
template <typename T, typename E>
struct segtree {
  using F = function<T(T,T)>;
  using G = function<T(T,E)>;
  using H = function<E(E,E)>;
  using P = function<E(E,int)>;
  F f; G g; H h; P p; T d1; E d0;
  int n;
  vector<T> dat;
  vector<E> lazy;

  segtree(){}
  segtree(int n_, F f_, G g_, H h_, T d1_, E d0_, P p_=[](E a, int b){return a;}):
    f(f_), g(g_), h(h_), p(p_), d1(d1_), d0(d0_) {
    n = 1; while(n < n_) n *= 2;
    dat.assign(n*2, d1);
    lazy.assign(n*2, d0);
  }
  void build(vector<T> v) {
    REP(i, v.size()) dat[i+n-1] = v[i];
    for(int i=n-2; i>=0; --i) dat[i] = f(dat[i*2+1], dat[i*2+2]);
  }

  // 区間の幅がlenの節点kについて遅延評価
  inline void eval(int len, int k) {
    if(lazy[k] == d0) return;
    if(k*2+1 < n*2-1) {
      lazy[2*k+1] = h(lazy[k*2+1], lazy[k]);
      lazy[2*k+2] = h(lazy[k*2+2], lazy[k]);
    }
    dat[k] = g(dat[k],p(lazy[k],len));
    lazy[k] = d0;
  }
  // [a, b)
  T update(int a, int b, E x, int k, int l, int r) {
    eval(r-l, k);
    if(b <= l || r <= a) return dat[k];
    if(a <= l && r <= b) {
      lazy[k] = h(lazy[k], x);
      return g(dat[k], p(lazy[k],r-l));
    }
    return dat[k] = f(update(a, b, x, 2*k+1, l, (l+r)/2),
                      update(a, b, x, 2*k+2, (l+r)/2, r));
  }
  T update(int a, int b, E x) { return update(a, b, x, 0, 0, n); }
  // [a, b)
  T query(int a, int b, int k, int l, int r) {
    eval(r-l, k);
    if(a <= l && r <= b) return dat[k];
    bool left = !((l+r)/2 <= a || b <= l), right = !(r <= 1 || b <= (l+r)/2);
    if(left&&right) return f(query(a, b, 2*k+1, l, (l+r)/2), query(a, b, 2*k+2, (l+r)/2, r));
    if(left) return query(a, b, 2*k+1, l, (l+r)/2);
    return query(a, b, 2*k+2, (l+r)/2, r);
  }
  T query(int a, int b) { return query(a, b, 0, 0, n); }
};

signed main(void)
{
  int H, W, q;
  scanf("%lld%lld%lld", &H, &W, &q);
  VI a(H), b(W);
  REP(i, H) scanf("%lld", &a[i]);
  REP(i, W) scanf("%lld", &b[i]);

  // セグ木を定義
  auto f = [](VI a, VI b) {
    // max,minの数
    int maxnum = a[0]==max(a[0],b[0])?a[1]:0;
    maxnum += b[0]==max(a[0],b[0])?b[1]:0;
    int minnum = a[2]==min(a[2],b[2])?a[3]:0;
    minnum += b[2]==min(a[2],b[2])?b[3]:0;
    return VI{
      max(a[0],b[0]),
      maxnum,
      min(a[2],b[2]),
      minnum
    };
  };
  auto g = [](VI a, int b) {
    return VI{
      a[0]+b,
      a[1],
      a[2]+b,
      a[3]
    };
  };
  auto h = [](int a, int b) {
    return a+b;
  };
  segtree<VI,int> seg1(H, f, g, h, {0,1,0,1}, 0),
                 seg2(W, f, g, h, {0,1,0,1}, 0);

  // 初期状態を設定
  VVI inita, initb;
  REP(i, H) inita.PB({a[i], 1, a[i], 1});
  REP(i, W) initb.PB({b[i], 1, b[i], 1});
  seg1.build(inita); seg2.build(initb);

  // クエリにこたえる
  REP(i, q) {
    int type;
    scanf("%lld", &type);
    if(type==1) {
      int a, b, v;
      scanf("%lld%lld%lld", &a, &b, &v);
      seg1.update(a-1, b, v);
    } else if(type==2) {
      int a, b, v;
      scanf("%lld%lld%lld", &a, &b, &v);
      seg2.update(a-1, b, v);
    } else if(type==4) {
      int ma = -LLINF, num = 0;
      int a, b, c, d;
      scanf("%lld%lld%lld%lld", &a, &b, &c, &d);
      VI v1 = seg1.query(a-1, b);
      VI v2 = seg2.query(c-1, d);
      ma = max({
        v1[0]*v2[0],
        v1[2]*v2[2],
        v1[0]*v2[2],
        v1[2]*v2[0]
      });
      // maxが0のケース
      if(ma == 0) {
        int tmpa = 0, tmpb = 0;
        if(v1[0] == 0) tmpa = v1[1];
        if(v1[2] == 0) tmpa = v1[3];
        if(v2[0] == 0) tmpb = v2[1];
        if(v2[2] == 0) tmpb = v2[3];
        num = (b-a+1)*(d-c+1) - (b-a+1-tmpa)*(d-c+1-tmpb);
      } else {
        if(v1[0] == v1[2] && v2[0] == v2[2]) {
          num = (b-a+1) * (d-c+1);
        } else if(v1[0] == v1[2] || v2[0] == v2[2]) {
          if(v1[2]*v2[2] == ma) num += v1[3]*v2[3];
          if(v1[0]*v2[0] == ma) num += v1[1]*v2[1];
        } else {
          if(v1[0]*v2[0] == ma) num += v1[1]*v2[1];
          if(v1[0]*v2[2] == ma) num += v1[1]*v2[3];
          if(v1[2]*v2[0] == ma) num += v1[3]*v2[1];
          if(v1[2]*v2[2] == ma) num += v1[3]*v2[3];
        }
      }
      printf("%lld %lld\n", ma, num);
    } else if(type==3) {
      int mi = LLINF, num = 0;
      int a, b, c, d;
      scanf("%lld%lld%lld%lld", &a, &b, &c, &d);
      VI v1 = seg1.query(a-1, b);
      VI v2 = seg2.query(c-1, d);
      mi = min({
        v1[0]*v2[0],
        v1[2]*v2[2],
        v1[0]*v2[2],
        v1[2]*v2[0]
      });
      // minが0の場合
      if(mi == 0) {
        int tmpa = 0, tmpb = 0;
        if(v1[0] == 0) tmpa = v1[1];
        if(v1[2] == 0) tmpa = v1[3];
        if(v2[0] == 0) tmpb = v2[1];
        if(v2[2] == 0) tmpb = v2[3];
        num = (b-a+1)*(d-c+1) - (b-a+1-tmpa)*(d-c+1-tmpb);
      } else {
        if(v1[0] == v1[2] && v2[0] == v2[2]) {
          num = (b-a+1) * (d-c+1);
        } else if(v1[0] == v1[2] || v2[0] == v2[2]) {
          if(v1[2]*v2[2] == mi) num += v1[3]*v2[3];
          if(v1[0]*v2[0] == mi) num += v1[1]*v2[1];
        } else {
          if(v1[0]*v2[0] == mi) num += v1[1]*v2[1];
          if(v1[0]*v2[2] == mi) num += v1[1]*v2[3];
          if(v1[2]*v2[0] == mi) num += v1[3]*v2[1];
          if(v1[2]*v2[2] == mi) num += v1[3]*v2[3];
        }
      }
      printf("%lld %lld\n", mi, num);
    }
  }

  return 0;
}