AtCoder Regular Contest 098

D - Xor Sum 2

  • 問題概要

長さnの数列がある。
A[l] xor A[l+1] xor ... xor A[r] = A[l] +A[l+1] + ... + A[r]となるl,r(1<=l<=r<=n)の組の数を求めよ。

  • 解法

l,rを自由に動かすとO(N^2)でTLE。
ここで 各bitで a xor b <= a + b であることに注目すると、lを固定したとき組(l,r)が条件をみたすならばr=l,l+1,...,rも条件を満たすことがわかる。
よって各lについて二分探索か尺取法でrの最大値を求めていけばいい。
xor も + も Monoidなので、区間xorと区間sumはセグメント木を使ってO(logN)で求められる。
二分探索で求めていくと計算量はO(N(logN)^2)で、厳し目だが一応通る(以下のコード)


冷静になってみると、区間xorは累積xorを、区間sumは累積和を用いることでO(1)で求められる。
xorの問題はbitごとに考えるという基本が何もできてなかったのと、segment木の理解が足りず手間取ってしまった。
本番では累積xorとBITを使った結果、前者がクエリに対して(l,r]、後者が[l,r)であることで混乱してバグを埋め込むことに。

#include <bits/stdc++.h> 

using namespace std;
typedef long long ll;

#define rep(i,n) for(int (i)=0;(i)<(n);(i)++)

template< typename Monoid >
struct SegmentTree
{
    using F = function< Monoid(Monoid, Monoid) >;

    int sz;
    vector< Monoid > seg;

    const F f;
    const Monoid M1;

    SegmentTree(int n, const F f, const Monoid &M1) : f(f), M1(M1)
    {
        sz = 1;
        while (sz < n) sz <<= 1;
        seg.assign(2 * sz, M1);
    }

    void set(int k, const Monoid &x)
    {
        seg[k + sz] = x;
    }

    void build()
    {
        for (int k = sz - 1; k > 0; k--) {
            seg[k] = f(seg[2 * k + 0], seg[2 * k + 1]);
        }
    }

    void update(int k, const Monoid &x)
    {
        k += sz;
        seg[k] = x;
        while (k >>= 1) {
            seg[k] = f(seg[2 * k + 0], seg[2 * k + 1]);
        }
    }

    Monoid query(int a, int b)
    {
        Monoid L = M1, R = M1;
        for (a += sz, b += sz; a < b; a >>= 1, b >>= 1) {
            if (a & 1) L = f(L, seg[a++]);
            if (b & 1) R = f(seg[--b], R);
        }
        return f(L, R);
    }

    Monoid operator[](const int &k) const
    {
        return seg[k + sz];
    }
};

const int N = 200010;
int n;
ll ret;
SegmentTree< ll > seg1(N, [](ll a, ll b) { return a + b; }, 0);
SegmentTree< ll > seg2(N, [](ll a, ll b) { return a ^ b; }, 0);

int main() {
    cin >> n;
    rep(i, n) {
        ll x;cin >> x;
        seg1.set(i, x);
        seg2.set(i, x);
    }
    seg1.build();
    seg2.build();
    rep(i, n) {
        ll ok = i, ng = n;
        while (abs(ng - ok) > 1) {
            ll mi = (ng + ok) / 2;
            if (seg1.query(i, mi + 1) == seg2.query(i, mi + 1)) ok = mi;
            else ng = mi;
        }
        ret += ok - i + 1;
    }
    cout << ret << endl;
}

更新なし区間クエリを、構築O(NlogN),クエリO(1)でできる Disjoint Sparse Table というデータ構造があるので、それを使うのもいいかもしれない。disjoint...