Library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub ret2home/Library

:heavy_check_mark: FM Index
(string/FM_index.cpp)

概要

高速な全文検索を提供する FM Indexの実装。JOI 夏季セミナー 2020 の成果物。

使い方（`occ` / `locate` のインターフェース）は Suffix Array と同じ。

計算量

Suffix Arrayはクエリが $O(\mid T \mid \log \mid S \mid)$ なのに対し、こちらの方が速いのが特長。

ただし、Suffix Arrayの方のクエリは定数倍が軽いので実際はあんまり変わらない事もある。

Depends on

Verified with

Code

#pragma once
#include "../structure/WaveletMatrix.cpp"
#include "../template/template.cpp"
#include "BWT.cpp"

template <class T, class C>
class FMIndex {
    ll N, base;
    T bwt;
    vector<ll> c;
    WaveletMatrix<T, C> WM;
    SuffixArray<T> SA;

   public:
    T ST;
    PL occ(T &S) {
        for (auto i : S)
            if ((ll)i < base || (ll)i - base >= len(c)) return PL(0, 0);
        ll sp = 0, ep = N;
        rev(i, len(S)) {
            sp = c[(ll)S[i] - base] + WM.rank(S[i], sp);
            ep = c[(ll)S[i] - base] + WM.rank(S[i], ep);
            if (sp >= ep) return PL(0, 0);
        }
        return PL(sp, ep);
    }
    vector<ll> locate(T &S) {
        vector<bool> v(len(ST) + 1);
        PL range = occ(S);
        for (ll i = range.first; i < range.second; i++) v[SA[i]] = true;
        vector<ll> res;
        rep(i, len(ST) + 1) if (v[i]) res.emplace_back(i);
        return res;
    }
    FMIndex(T S) : N(len(S) + 1), ST(S + '$'), WM("", 0), SA(S) {
        bwt = BWT(S, SA);
        WM = WaveletMatrix<T, C>(bwt, 8);
        ll mn = inf, mx = -inf;
        for (C i : ST) {
            chmin(mn, (ll)i);
            chmax(mx, (ll)i);
        }
        c.resize(mx - mn + 2);
        for (C i : ST) {
            c[(ll)i - mn + 1]++;
        }
        rep(i, mx - mn + 1) c[i + 1] += c[i];
        base = mn;
    }
};
/*
@brief FM Index
@docs docs/FM_index.md
*/
#line 2 "template/template.cpp"
#include <bits/stdc++.h>
using namespace std;
// Shorthand macros shared by the whole library.
// Every macro argument is parenthesised so that an expression argument
// (e.g. `a & b`, `cond ? x : y`) expands with the intended precedence.
#define ll long long
// rep/REP count upwards from 0/1 to n - 1; rev/REV count downwards from
// n - 1 to 0/1.
#define rep(i, n) for (int i = 0; i < (n); i++)
#define REP(i, n) for (int i = 1; i < (n); i++)
#define rev(i, n) for (int i = (n) - 1; i >= 0; i--)
#define REV(i, n) for (int i = (n) - 1; i > 0; i--)
#define all(v) (v).begin(), (v).end()
#define PL pair<ll, ll>
#define PI pair<int, int>
#define pi acos(-1)
#define len(s) ((int)(s).size())
// Sorts v and drops duplicates. Expands to ONE expression statement so that
// `if (cond) compress(v);` guards both steps; the trailing semicolon is kept
// so call sites written without their own semicolon still compile.
#define compress(v) (sort(all(v)), (v).erase(unique(all(v)), (v).end()));
// Index of the first element of sorted v that is not less than x.
#define comid(v, x) (lower_bound(all(v), x) - (v).begin())

// Min-heap alias: a priority_queue over vector<T> ordered with greater<>
// instead of the default less<>, so top() is the smallest element.
template<class T>
using prique=priority_queue<T,vector<T>,greater<>>;

// Lowers a to b when b is strictly smaller.
// Returns true exactly when a was changed.
template <class T, class U>
inline bool chmin(T &a, U b) {
    const bool lowered = a > b;
    if (lowered) a = b;
    return lowered;
}
// Raises a to b when b is strictly larger.
// Returns true exactly when a was changed.
template <class T, class U>
inline bool chmax(T &a, U b) {
    const bool raised = a < b;
    if (raised) a = b;
    return raised;
}
// Sentinel "infinity" for ll computations: 3 * 10^18, which leaves headroom
// for adding two such values without exceeding the signed 64-bit range.
constexpr ll inf = 3'000'000'000'000'000'000LL;
#line 3 "structure/BitVector.cpp"

/**
 * Static bit sequence with constant-time rank queries.
 *
 * Bits are packed into 64-bit words; sum[w] holds the number of set bits in
 * all words before word w.
 */
class BitVector {
    std::vector<long long> sum;
    std::vector<uint64_t> bit;

   public:
    /**
     * Number of positions in [0, idx) whose bit equals val.
     * idx must lie in [0, size of the source vector].
     */
    long long rank(bool val, long long idx) const {
        const long long word = idx >> 6;
        // Mask selecting the bits of this word that lie strictly below idx.
        const uint64_t below = (uint64_t{1} << (idx & 63)) - 1;
        const long long ones = sum[word] + __builtin_popcountll(bit[word] & below);
        return val ? ones : idx - ones;
    }

    /**
     * Packs v and precomputes the per-word prefix counts.
     * Left implicit on purpose: WaveletMatrix fills its table with
     * index.resize(bitlen, bit), which converts a vector<bool>.
     */
    BitVector(const std::vector<bool>& v) {
        const long long n = static_cast<long long>(v.size());
        // One spare word so that rank(_, n) is valid when n is a multiple of 64.
        const long long words = (n >> 6) + 1;
        bit.assign(words, 0);
        sum.assign(words, 0);
        for (long long i = 0; i < n; i++) {
            if (v[i]) bit[i >> 6] |= uint64_t{1} << (i & 63);
        }
        for (long long w = 0; w + 1 < words; w++) {
            sum[w + 1] = sum[w] + __builtin_popcountll(bit[w]);
        }
    }
};
/*
@brief Bit Vector
@docs docs/BitVector.md
*/
#line 4 "structure/WaveletMatrix.cpp"

template <class T, class C>
class WaveletMatrix {
    int N, bitlen;
    vector<BitVector> index;
    map<C, int> st;

   public:
    T body;
    int rank(C c, int idx) {
        if (st.find(c) == st.end()) return 0;
        rev(i, bitlen) {
            if (c >> i & 1)
                idx = index[i].rank(1, idx) + index[i].rank(0, N);
            else
                idx -= index[i].rank(1, idx);
        }
        return max(0, idx - st[c]);
    }
    C quantile(int l, int r, int c) {
        C res = 0;
        rev(i, bitlen) {
            ll cnt = (r - l) - (index[i].rank(1, r) - index[i].rank(1, l));
            if (cnt <= c) {
                c -= cnt;
                l = index[i].rank(0, N) + index[i].rank(1, l);
                r = index[i].rank(0, N) + index[i].rank(1, r);
                res += 1ll << i;
            } else {
                l -= index[i].rank(1, l);
                r -= index[i].rank(1, r);
            }
        }
        return res;
    }
    WaveletMatrix(T V, ll bitlen) : N(len(V)), bitlen(bitlen), body(V) {
        vector<bool> bit(N);
        index.resize(bitlen, bit);
        rev(i, bitlen) {
            T newV[2];
            rep(j, N) {
                bit[j] = (V[j] >> i & 1);
                newV[V[j] >> i & 1].push_back(V[j]);
            }
            V = newV[0];
            V.insert(V.end(), all(newV[1]));
            index[i] = BitVector(bit);
        }
        rep(i, N) if (st.find(V[i]) == st.end()) st[V[i]] = i;
    }
};
/*
@brief Wavelet Matrix
@docs docs/WaveletMatrix.md
*/
#line 3 "string/SuffixArray.cpp"

/**
 * Suffix array of a sequence built by induced sorting (SA-IS), together with
 * the LCP array of adjacent suffixes and prefix search by binary search.
 *
 * Internally the text is remapped to 1..(max - min + 1) and a terminator 0 is
 * appended, so the array has len(text) + 1 rows and row 0 is always the
 * terminator-only suffix (position len(text)).
 */
template <class T>
class SuffixArray {
// Suffix classes used by the induced sort, encoded as pair<bool, bool>.
#define typeS make_pair(false, false)
#define LMS make_pair(false, true)
#define typeL make_pair(true, true)
    using TYPE = pair<bool, bool>;

    // Classifies every suffix as S-type (smaller than its successor), L-type
    // (larger), or LMS (an S-type directly preceded by an L-type). The last
    // position, the terminator, is LMS by definition.
    vector<TYPE> assignType(vector<ll> &S) {
        vector<TYPE> type(len(S));
        type[len(S) - 1] = LMS;
        for (ll i = len(S) - 2; i >= 0; i--) {
            if (S[i] < S[i + 1])
                type[i] = typeS;
            else if (S[i] > S[i + 1]) {
                type[i] = typeL;
                if (type[i + 1] == typeS) type[i + 1] = LMS;
            } else
                type[i] = type[i + 1];
        }
        return type;
    }

    // bucket[x] = number of symbols <= x, i.e. one past the end of x's bucket.
    vector<ll> getBucket(vector<ll> &S, ll alph) {
        vector<ll> bucket(alph);
        for (ll i : S) bucket[i]++;
        rep(i, len(bucket) - 1) bucket[i + 1] += bucket[i];
        return bucket;
    }

    // Left-to-right pass: places each L-type suffix at the front of its
    // bucket. bucket[x - 1] is the start of x's bucket; SA is read while it is
    // being filled, so entries written here are visited later in the same scan.
    void sortTypeL(vector<ll> &S, vector<ll> &SA, vector<TYPE> &type, ll alph) {
        vector<ll> bucket = getBucket(S, alph);
        for (ll i : SA) {
            if (i > 0 && type[i - 1] == typeL) SA[bucket[S[i - 1] - 1]++] = i - 1;
        }
    }

    // Right-to-left pass: places each S-type (including LMS) suffix at the
    // back of its bucket.
    void sortTypeS(vector<ll> &S, vector<ll> &SA, vector<TYPE> &type, ll alph) {
        vector<ll> bucket = getBucket(S, alph);
        rev(j, len(S)) {
            ll i = SA[j];
            if (i > 0 && (type[i - 1] == typeS || type[i - 1] == LMS)) SA[--bucket[S[i - 1]]] = i - 1;
        }
    }

    // Suffix array of S, whose symbols lie in [0, alph) and whose last symbol
    // is a unique minimum 0.
    vector<ll> InducedSorting(vector<ll> &S, ll alph) {
        vector<ll> SA(len(S), -1);
        vector<TYPE> type = assignType(S);
        vector<ll> bucket = getBucket(S, alph);
        // nextlms[p]: the LMS position following LMS position p (itself for
        // the last one); ordered_lms: LMS positions in text order.
        vector<ll> nextlms(len(S), -1), ordered_lms;
        ll lastlms = -1;
        // Step 1: seed the bucket ends with the LMS suffixes in text order.
        rep(i, len(S)) if (type[i] == LMS) {
            SA[--bucket[S[i]]] = i;
            if (lastlms != -1) nextlms[lastlms] = i;
            lastlms = i;
            ordered_lms.emplace_back(i);
        }
        nextlms[lastlms] = lastlms;
        // Step 2: induce; afterwards the LMS substrings appear in sorted order.
        sortTypeL(S, SA, type, alph);
        sortTypeS(S, SA, type, alph);
        vector<ll> lmses;
        for (ll i : SA)
            if (type[i] == LMS) lmses.emplace_back(i);
        // Step 3: name the LMS substrings; equal substrings share a name.
        ll nowrank = 0;
        vector<ll> newS = {0};
        REP(i, len(lmses)) {
            ll pre = lmses[i - 1], now = lmses[i];
            if (nextlms[pre] - pre != nextlms[now] - now)
                newS.emplace_back(++nowrank);
            else {
                bool flag = false;
                rep(j, nextlms[pre] - pre + 1) {
                    if (S[pre + j] != S[now + j]) {
                        flag = true;
                        break;
                    }
                }
                if (flag)
                    newS.emplace_back(++nowrank);
                else
                    newS.emplace_back(nowrank);
            }
        }
        // Step 4: if names repeat, the LMS suffix order is still ambiguous;
        // recurse on the string of names taken in text order.
        if (nowrank + 1 != len(lmses)) {
            vector<ll> V(len(S), -1);
            rep(i, len(lmses)) {
                V[lmses[i]] = newS[i];
            }
            vector<ll> newnewS;
            rep(i, len(S)) if (V[i] != -1) newnewS.emplace_back(V[i]);
            vector<ll> SA_ = InducedSorting(newnewS, nowrank + 1);
            rep(i, len(SA_)) {
                lmses[i] = ordered_lms[SA_[i]];
            }
        }
        // Step 5: seed the buckets with the fully sorted LMS suffixes and
        // induce once more to obtain the final order.
        SA.assign(len(S), -1);
        bucket = getBucket(S, alph);
        rev(i, len(lmses)) {
            SA[--bucket[S[lmses[i]]]] = lmses[i];
        }
        sortTypeL(S, SA, type, alph);
        sortTypeS(S, SA, type, alph);
        return SA;
    }
    vector<ll> SA;
    T ST;  // the indexed text, without terminator

   private:
    // Compares the suffix starting at index against pattern S:
    //   0  S is a prefix of the suffix,
    //   1  the suffix sorts before S (smaller symbol, or it ends first),
    //  -1  the suffix sorts after S.
    ll ismatch(T &S, ll index) {
        rep(i, len(S)) {
            if (i + index >= len(ST)) return 1;
            if (ST[i + index] < S[i]) return 1;
            if (ST[i + index] > S[i]) return -1;
        }
        return 0;
    }

   public:
    /**
     * Half-open range [first, second) of rows whose suffix starts with S.
     * Both searches keep row 0 (the terminator-only suffix) on the
     * "before the range" side, so it is never probed or reported.
     */
    PL occ(T &S) {
        // First row whose suffix is not smaller than S.
        ll okl = len(ST) + 1, ngl = 0;
        while (okl - ngl > 1) {
            ll mid = (okl + ngl) / 2;
            if (ismatch(S, SA[mid]) <= 0)
                okl = mid;
            else
                ngl = mid;
        }
        // First row whose suffix is larger than S and does not start with it.
        ll okr = len(ST) + 1, ngr = 0;
        while (okr - ngr > 1) {
            ll mid = (okr + ngr) / 2;
            if (ismatch(S, SA[mid]) < 0)
                okr = mid;
            else
                ngr = mid;
        }
        return PL(okl, okr);
    }

    /**
     * Start positions of every occurrence of S in the text, ascending.
     */
    vector<ll> locate(T &S) {
        vector<bool> v(len(ST) + 1);
        PL range = occ(S);
        for (ll i = range.first; i < range.second; i++) v[SA[i]] = true;
        vector<ll> res;
        rep(i, len(ST) + 1) if (v[i]) res.emplace_back(i);
        return res;
    }

    // Text position of the suffix in row k.
    ll operator[](ll k) const { return SA[k]; }

   public:
    // LCP[k]: length of the common prefix of the suffixes in rows k and k + 1.
    // Has len(text) + 1 entries; the last one is never written and stays 0.
    vector<ll> LCP;

   private:
    // Fills LCP by walking the suffixes in text order and reusing the
    // previous match length minus one as the starting point.
    void constructLCP() {
        vector<ll> rank(len(ST) + 1);
        LCP.resize(len(ST) + 1);
        rep(i, len(ST) + 1) rank[SA[i]] = i;
        ll h = 0;
        rep(i, len(ST)) {
            // rank[i] >= 1 here because row 0 belongs to the terminator.
            ll j = SA[rank[i] - 1];
            if (h > 0) h--;
            for (; j + h < len(ST) && i + h < len(ST); h++) {
                if (ST[j + h] != ST[i + h]) break;
            }
            LCP[rank[i] - 1] = h;
        }
    }

   public:
    /**
     * Builds the suffix array and LCP array of S. An empty S is accepted and
     * yields the single terminator row.
     */
    SuffixArray(T S) : ST(S) {
        ll mn = inf, mx = -inf;
        for (auto i : S) {
            chmin(mn, (ll)i);
            chmax(mx, (ll)i);
        }
        // Shift symbols to start at 1 so that 0 is free for the terminator.
        vector<ll> newS;
        for (auto i : S) newS.emplace_back(i - mn + 1);
        newS.emplace_back(0);
        // With no symbols mn/mx keep their +-inf seeds and mx - mn + 2 would
        // be hugely negative; the alphabet is then just the terminator.
        const ll alph = len(S) == 0 ? 1 : mx - mn + 2;
        SA = InducedSorting(newS, alph);
        constructLCP();
    }
};

/*
@brief Suffix Array (SA-IS)
@docs docs/SuffixArray.md
*/
#line 4 "string/BWT.cpp"

template <class T>
T BWT(T S, SuffixArray<T>& SA) {
    S += '$';
    T bwt;
    rep(i, len(S)) {
        bwt.push_back(S[(SA[i] - 1 + len(S)) % len(S)]);
    }
    return bwt;
}

/**
 * Inverts a Burrows-Wheeler transform whose terminator is '$' and sorts
 * before every other symbol.
 *
 * @param S the transformed sequence (last column), containing the '$'.
 * @return the original sequence without the terminator. An empty sequence is
 *         returned when S is empty or holds no '$', since such input is not a
 *         valid transform. If S holds several '$', the first one is taken as
 *         the terminator.
 */
template <class T>
T BWTInverse(T S) {
    const long long n = static_cast<long long>(S.size());
    T res;
    if (n == 0) return res;

    // Locate the terminator and the value range of the real symbols.
    bool seen = false;
    long long lo = 0, hi = 0, start = -1;
    for (long long i = 0; i < n; i++) {
        if (S[i] == '$') {
            if (start < 0) start = i;
            continue;
        }
        const long long v = static_cast<long long>(S[i]);
        if (!seen) {
            lo = hi = v;
            seen = true;
        } else {
            lo = std::min(lo, v);
            hi = std::max(hi, v);
        }
    }
    if (start < 0) return res;

    // Sort key: 0 for '$', otherwise the symbol offset by the minimum so that
    // negative symbol values are handled as well.
    const auto keyOf = [&S, lo](long long i) -> long long {
        return S[i] == '$' ? 0 : static_cast<long long>(S[i]) - lo + 1;
    };
    const long long sigma = seen ? hi - lo + 2 : 1;

    // Stable counting sort of the positions of S by key. order[r] is the
    // position in S of the symbol occurrence that heads sorted row r.
    std::vector<long long> first(sigma + 1, 0);
    for (long long i = 0; i < n; i++) first[keyOf(i) + 1]++;
    for (long long k = 0; k < sigma; k++) first[k + 1] += first[k];
    std::vector<long long> order(n);
    for (long long i = 0; i < n; i++) order[first[keyOf(i)]++] = i;

    // The row ending in '$' is the text itself; each step emits the first
    // symbol of the current row and moves to the row that starts one symbol
    // later in the text.
    long long row = start;
    for (long long step = 0; step + 1 < n; step++) {
        row = order[row];
        res.push_back(S[row]);
    }
    return res;
}
/*
@brief Burrows Wheeler Transform
@docs docs/BWT.md
*/
#line 5 "string/FM_index.cpp"

template <class T, class C>
class FMIndex {
    ll N, base;
    T bwt;
    vector<ll> c;
    WaveletMatrix<T, C> WM;
    SuffixArray<T> SA;

   public:
    T ST;
    PL occ(T &S) {
        for (auto i : S)
            if ((ll)i < base || (ll)i - base >= len(c)) return PL(0, 0);
        ll sp = 0, ep = N;
        rev(i, len(S)) {
            sp = c[(ll)S[i] - base] + WM.rank(S[i], sp);
            ep = c[(ll)S[i] - base] + WM.rank(S[i], ep);
            if (sp >= ep) return PL(0, 0);
        }
        return PL(sp, ep);
    }
    vector<ll> locate(T &S) {
        vector<bool> v(len(ST) + 1);
        PL range = occ(S);
        for (ll i = range.first; i < range.second; i++) v[SA[i]] = true;
        vector<ll> res;
        rep(i, len(ST) + 1) if (v[i]) res.emplace_back(i);
        return res;
    }
    FMIndex(T S) : N(len(S) + 1), ST(S + '$'), WM("", 0), SA(S) {
        bwt = BWT(S, SA);
        WM = WaveletMatrix<T, C>(bwt, 8);
        ll mn = inf, mx = -inf;
        for (C i : ST) {
            chmin(mn, (ll)i);
            chmax(mx, (ll)i);
        }
        c.resize(mx - mn + 2);
        for (C i : ST) {
            c[(ll)i - mn + 1]++;
        }
        rep(i, mx - mn + 1) c[i + 1] += c[i];
        base = mn;
    }
};
/*
@brief FM Index
@docs docs/FM_index.md
*/
Back to top page