ACM_Notebook_new

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub ngthanhtrung23/ACM_Notebook_new

:heavy_check_mark: DataStructure/test/splay_tree.test.cpp

Depends on

Code

#define PROBLEM "https://judge.yosupo.jp/problem/dynamic_sequence_range_affine_range_sum"

#include <bits/stdc++.h>
using namespace std;

#include "../../Math/modint.h"
#include "../splay_tree.h"

using modular = ModInt<998244353>;

struct S {
    long long sum, sz;
};
struct F {
    long long a, b;
};
using Node = node_t<int, S, F>;

const int MOD = 998244353;
S op(S left, int key, S right) {
    return S {
        (left.sum + key + right.sum) % MOD,
        left.sz + 1 + right.sz,
    };
}
pair<int, S> e() {
    return {0, {0, 0}};
}
pair<int, S> mapping(F f, Node* node) {
    return {
        (f.a * node->key + f.b) % MOD,
        S {
            (f.a * node->data.sum + f.b * node->data.sz) % MOD,
            node->data.sz,
        }
    };
}
F composition(F f, F g) {
    return F {
        f.a * g.a % MOD,
        (f.a * g.b + f.b) % MOD,
    };
}
F id() {
    return F {1, 0};
}

int32_t main() {
    ios::sync_with_stdio(0); cin.tie(0);
    int n, q; cin >> n >> q;
    vector<int> keys(n);
    for (auto& key : keys) cin >> key;

    SplayTreeById<
        int,
        S,
        op,
        e,
        F,
        mapping,
        composition,
        id
    > tree(keys);

    while (q--) {
        int typ; cin >> typ;
        if (typ == 0) {
            int pos, val; cin >> pos >> val;
            tree.insert(pos, val);
        } else if (typ == 1) {
            int pos; cin >> pos;
            tree.erase(pos);
        } else if (typ == 2) {
            int l, r; cin >> l >> r;
            tree.reverse(l, r);
        } else if (typ == 3) {
            int l, r, a, b; cin >> l >> r >> a >> b;
            tree.apply(l, r, F{a, b});
        } else {
            assert(typ == 4);
            int l, r; cin >> l >> r;
            printf("%lld\n", tree.prod(l, r).sum);
        }
    }
    return 0;
}
#line 1 "DataStructure/test/splay_tree.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/dynamic_sequence_range_affine_range_sum"

#include <bits/stdc++.h>
using namespace std;

#line 1 "Math/modint.h"
// ModInt {{{
template<int MD> struct ModInt {
    using ll = long long;
    int x;

    constexpr ModInt() : x(0) {}
    constexpr ModInt(ll v) { _set(v % MD + MD); }
    constexpr static int mod() { return MD; }
    constexpr explicit operator bool() const { return x != 0; }

    constexpr ModInt operator + (const ModInt& a) const {
        return ModInt()._set((ll) x + a.x);
    }
    constexpr ModInt operator - (const ModInt& a) const {
        return ModInt()._set((ll) x - a.x + MD);
    }
    constexpr ModInt operator * (const ModInt& a) const {
        return ModInt()._set((ll) x * a.x % MD);
    }
    constexpr ModInt operator / (const ModInt& a) const {
        return ModInt()._set((ll) x * a.inv().x % MD);
    }
    constexpr ModInt operator - () const {
        return ModInt()._set(MD - x);
    }

    constexpr ModInt& operator += (const ModInt& a) { return *this = *this + a; }
    constexpr ModInt& operator -= (const ModInt& a) { return *this = *this - a; }
    constexpr ModInt& operator *= (const ModInt& a) { return *this = *this * a; }
    constexpr ModInt& operator /= (const ModInt& a) { return *this = *this / a; }

    friend constexpr ModInt operator + (ll a, const ModInt& b) {
        return ModInt()._set(a % MD + b.x);
    }
    friend constexpr ModInt operator - (ll a, const ModInt& b) {
        return ModInt()._set(a % MD - b.x + MD);
    }
    friend constexpr ModInt operator * (ll a, const ModInt& b) {
        return ModInt()._set(a % MD * b.x % MD);
    }
    friend constexpr ModInt operator / (ll a, const ModInt& b) {
        return ModInt()._set(a % MD * b.inv().x % MD);
    }

    constexpr bool operator == (const ModInt& a) const { return x == a.x; }
    constexpr bool operator != (const ModInt& a) const { return x != a.x; }

    friend std::istream& operator >> (std::istream& is, ModInt& other) {
        ll val; is >> val;
        other = ModInt(val);
        return is;
    }
    constexpr friend std::ostream& operator << (std::ostream& os, const ModInt& other) {
        return os << other.x;
    }

    constexpr ModInt pow(ll k) const {
        ModInt ans = 1, tmp = x;
        while (k) {
            if (k & 1) ans *= tmp;
            tmp *= tmp;
            k >>= 1;
        }
        return ans;
    }

    constexpr ModInt inv() const {
        if (x < 1000111) {
            _precalc(1000111);
            return invs[x];
        }
        int a = x, b = MD, ax = 1, bx = 0;
        while (b) {
            int q = a/b, t = a%b;
            a = b; b = t;
            t = ax - bx*q;
            ax = bx; bx = t;
        }
        assert(a == 1);
        if (ax < 0) ax += MD;
        return ax;
    }

    static std::vector<ModInt> factorials, inv_factorials, invs;
    constexpr static void _precalc(int n) {
        if (factorials.empty()) {
            factorials = {1};
            inv_factorials = {1};
            invs = {0};
        }
        if (n > MD) n = MD;
        int old_sz = factorials.size();
        if (n <= old_sz) return;

        factorials.resize(n);
        inv_factorials.resize(n);
        invs.resize(n);

        for (int i = old_sz; i < n; ++i) factorials[i] = factorials[i-1] * i;
        inv_factorials[n-1] = factorials.back().pow(MD - 2);
        for (int i = n - 2; i >= old_sz; --i) inv_factorials[i] = inv_factorials[i+1] * (i+1);
        for (int i = n-1; i >= old_sz; --i) invs[i] = inv_factorials[i] * factorials[i-1];
    }

    static int get_primitive_root() {
        static int primitive_root = 0;
        if (!primitive_root) {
            primitive_root = [&]() {
                std::set<int> fac;
                int v = MD - 1;
                for (ll i = 2; i * i <= v; i++)
                    while (v % i == 0) fac.insert(i), v /= i;
                if (v > 1) fac.insert(v);
                for (int g = 1; g < MD; g++) {
                    bool ok = true;
                    for (auto i : fac)
                        if (ModInt(g).pow((MD - 1) / i) == 1) {
                            ok = false;
                            break;
                        }
                    if (ok) return g;
                }
                return -1;
            }();
        }
        return primitive_root;
    }

    static ModInt C(int n, int k) {
        _precalc(n + 1);
        return factorials[n] * inv_factorials[k] * inv_factorials[n-k];
    }
    
private:
    // Internal, DO NOT USE.
    // val must be in [0, 2*MD)
    constexpr inline __attribute__((always_inline)) ModInt& _set(ll v) {
        x = v >= MD ? v - MD : v;
        return *this;
    }
};
template <int MD> std::vector<ModInt<MD>> ModInt<MD>::factorials = {1};
template <int MD> std::vector<ModInt<MD>> ModInt<MD>::inv_factorials = {1};
template <int MD> std::vector<ModInt<MD>> ModInt<MD>::invs = {0};
// }}}
#line 1 "DataStructure/splay_tree.h"
// SplayTreeById
//
// Note:
// - op() must be commutative, otherwise reverse queries won't work.
//   To fix it, need to store aggregate data from right->left
//   See https://judge.yosupo.jp/submission/53778 (and look at invsum)
//
// Tested:
// - (cut, join)      https://vn.spoj.com/problems/CARDS/
// - (keys, reverse)  https://oj.vnoi.info/problem/twist
// - (insert, prod)   https://oj.vnoi.info/problem/qmax3vn
// - (insert, delete) https://vn.spoj.com/problems/QMAX4/
// - (insert, delete) https://vn.spoj.com/problems/CARDSHUF/
// - (lazy)           https://judge.yosupo.jp/problem/dynamic_sequence_range_affine_range_sum
// - (lazy)           https://oj.vnoi.info/problem/upit
template<class K, class S, class F>
struct node_t {
    using Node = node_t<K, S, F>;

    std::array<Node*, 2> child;
    Node *father;
    int size;
    
    // Whether we will need to reverse this subtree.
    // Handling reverse operations requires some specialized code,
    // so I couldn't put this in F
    bool reverse;

    K key;
    S data;
    F lazy;
};
template<
    class K,                               // key
    class S,                               // node aggregate data
    S (*op) (S, K, S),                     // for recomputing data of a node
    pair<K, S> (*e) (),                    // identity data
    class F,                               // lazy propagation tag
    pair<K, S> (*mapping) (F, node_t<K, S, F>*),  // apply tag F on a node
    F (*composition) (F, F),               // combine 2 tags
    F (*id)()                              // identity tag
>
struct SplayTreeById {
    using Node = node_t<K, S, F>;

    Node *nil, *root;

    SplayTreeById() {
        initNil();
        root = nil;
    }
    SplayTreeById(const vector<K>& keys) {
        initNil();
        root = createNode(keys, 0, (int) keys.size());
    }

    vector<K> getKeys() {
        vector<K> keys;
        traverse(root, keys);
        return keys;
    }

    // k in [0, n-1]
    Node* kth(int k) {
        auto res = _kth(root, k);
        splay(res);
        root = res;
        return res;
    }

    // Return <L, R>:
    // - L contains [0, k-1]
    // - R contains [k, N-1]
    // Modify tree
    pair<Node*, Node*> cut(int k) {
        if (k == 0) {
            return {nil, root};
        } else if (k == root->size) {
            return {root, nil};
        } else {
            Node *left = kth(k - 1);  // kth already splayed
            Node* right = left->child[1];
            left->child[1] = right->father = nil;
            pushUp(left);
            return {left, right};
        }
    }

    // Return <X, Y, Z>:
    // - X contains [0, u-1]
    // - Y contains [u, v-1]
    // - Z contains [v, N-1]
    // This is useful for queries on [u, v-1]
    // Modify tree
    tuple<Node*, Node*, Node*> cut(int u, int v) {
        auto [xy, z] = cut(v);
        root = xy;
        auto [x, y] = cut(u);
        return {x, y, z};
    }

    // Make this tree x + y
    void join(Node *x, Node *y) {
        if (x == nil) {
            root = y;
            return;
        }
        while (1) {
            pushDown(x);
            if (x->child[1] == nil) break;
            x = x->child[1];
        }
        splay(x);
        setChild(x, y, 1);
        pushUp(x);
        root = x;
    }

    // reverse range [u, v-1]
    void reverse(int u, int v) {
        assert(0 <= u && u <= v && v <= root->size);
        if (u == v) return;

        auto [x, y, z] = cut(u, v);
        y->reverse = true;
        join(x, y);
        join(root, z);
    }

    // apply F on range [u, v-1]
    void apply(int u, int v, const F& f) {
        assert(0 <= u && u <= v && v <= root->size);
        if (u == v) return;

        auto [x, y, z] = cut(u, v);
        y->lazy = composition(f, y->lazy);
        std::tie(y->key, y->data) = mapping(f, y);

        join(x, y);
        join(root, z);
    }

    // Insert before pos
    // pos in [0, N]
    void insert(int pos, K key) {
        assert(0 <= pos && pos <= root->size);
        // x: [0, pos-1]
        // y: [pos, N-1]
        auto [x, y] = cut(pos);
        auto node = createNode(key);
        setChild(node, x, 0);
        setChild(node, y, 1);
        pushUp(node);
        root = node;
    }

    // Delete pos; pos in [0, N-1]
    K erase(int pos) {
        assert(0 <= pos && pos < root->size);

        // x = [0, pos-1]
        // y = [pos, pos]
        // z = [pos+1, N-1]
        auto [x, y, z] = cut(pos, pos+1);
        join(x, z);
        return y->key;
    }

    // aggregated data of range [l, r-1]
    S prod(int l, int r) {
        auto [x, y, z] = cut(l, r);
        auto res = y->data;
        join(x, y);
        join(root, z);
        return res;
    }

// private:
    void initNil() {
        nil = new Node();
        nil->child[0] = nil->child[1] = nil->father = nil;
        nil->size = 0;
        nil->reverse = false;
        std::tie(nil->key, nil->data) = e();
        nil->lazy = id();
    }
    void pushUp(Node* x) {
        if (x == nil) return;
        x->size = x->child[0]->size + x->child[1]->size + 1;
        x->data = op(x->child[0]->data, x->key, x->child[1]->data);
    }
    void pushDown(Node* x) {
        if (x == nil) return;

        if (x->reverse) {
            for (auto c : x->child) {
                if (c != nil) {
                    c->reverse ^= 1;
                }
            }
            std::swap(x->child[0], x->child[1]);
            x->reverse = false;
        }

        for (auto c : x->child) {
            if (c != nil) {
                std::tie(c->key, c->data) = mapping(x->lazy, c);
                c->lazy = composition(x->lazy, c->lazy);
            }
            // For problem like UPIT, where we want to push different
            // lazy tags to left & right children, may need to modify
            // code here
            // (query L R X: a(L) += X, a(L+1) += 2X, ...)
            // e.g. for UPIT:
            // x->lazy.add_left += (1 + c->size) * x->lazy.step;
        }

        x->lazy = id();
    }
    Node* createNode(K key) {
        Node *res = new Node();
        res->child[0] = res->child[1] = res->father = nil;
        res->key = key;
        res->size = 1;
        res->data = e().second;
        res->lazy = id();
        return res;
    }
    void setChild(Node *x, Node *y, int d) {
        x->child[d] = y;
        if (y != nil) y->father = x;
    }
    // Assumption: x is father of y
    int getDirection(Node *x, Node *y) {
        assert(y->father == x);
        return x->child[0] == y ? 0 : 1;
    }
    // create subtree from keys[l, r-1]
    Node* createNode(const vector<K>& keys, int l, int r) {
        if (l >= r) {  // empty
            return nil;
        }
        int mid = (l + r) / 2;
        Node *p = createNode(keys[mid]);
        Node *left = createNode(keys, l, mid);
        Node *right = createNode(keys, mid + 1, r);

        setChild(p, left, 0);
        setChild(p, right, 1);

        pushUp(p);
        return p;
    }
    void traverse(Node* x, vector<K>& keys) {
        if (x == nil) return;
        pushDown(x);
        traverse(x->child[0], keys);
        keys.push_back(x->key);
        traverse(x->child[1], keys);
    }
    /**
     * Before:
     *    y
     *    |
     *    x
     *  /
     * z
     *  \
     *  zchild
     *
     * After:
     *    y
     *    |
     *    z
     *     \
     *      x
     *     /
     *  zchild
     */
    void rotate(Node *x, int d) {
        Node *y = x->father;
        Node *z = x->child[d];
        setChild(x, z->child[d ^ 1], d);
        setChild(y, z, getDirection(y, x));
        setChild(z, x, d ^ 1);
        pushUp(x);
        pushUp(z);
    }
    // Make x root of tree
    Node *splay(Node *x) {
        if (x == nil) return nil;
        while (x->father != nil) {
            Node *y = x->father;
            Node *z = y->father;
            int dy = getDirection(y, x);
            int dz = getDirection(z, y);
            if (z == nil) {
                rotate(y, dy);
            } else if (dy == dz) {
                rotate(z, dz);
                rotate(y, dy);
            } else {
                rotate(y, dy);
                rotate(z, dz);
            }
        }
        return x;
    }

    Node* _kth(Node* p, int k) {
        pushDown(p);
        // left: [0, left->size - 1]
        if (k < p->child[0]->size) {
            return _kth(p->child[0], k);
        }
        k -= p->child[0]->size;
        if (!k) return p;
        return _kth(p->child[1], k - 1);
    }
};

////////// Below: example usage
// Splay tree only need to store keys (no aggregated value / no lazy update)
struct KeyOnlyOps {
    struct S{};
    struct F{};
    using Node = node_t<int, S, F>;
    
    static S op(__attribute__((unused)) S left, __attribute__((unused)) int key, __attribute__((unused)) S right) {
        return {};
    }
    static pair<int, S> e() {
        return {-1, {}};
    }
    static pair<int, S> mapping(__attribute__((unused)) F f, Node* node) {
        return {node->key, {}};
    }
    static F composition(__attribute__((unused)) F f, __attribute__((unused)) F g) {
        return {};
    }
    static F id() {
        return {};
    }
};

/* Example:
    SplayTreeById<
        int,
        KeyOnlyOps::S,
        KeyOnlyOps::op,
        KeyOnlyOps::e,
        KeyOnlyOps::F,
        KeyOnlyOps::mapping,
        KeyOnlyOps::composition,
        KeyOnlyOps::id
    > tree(keys);
 */


// For query get max of keys in range
// No lazy update tags
struct MaxQueryOps {
    static const int INF = 1e9 + 11;
    struct F{};
    using Node = node_t<int, int, F>;

    static int op(const int& left, int key, const int& right) {
        return max({left, key, right});
    }
    static pair<int, int> e() {
        return {-1, -INF};
    }
    static pair<int, int> mapping(__attribute__((unused)) const F& f, Node* node) {
        return {node->key, node->data};
    }
    static F composition(__attribute__((unused)) const F& f, __attribute__((unused)) const F& g) {
        return {};
    }
    static F id() {
        return {};
    }
};
/* Example:
    SplayTreeById<
        int,
        int,
        MaxQueryOps::op,
        MaxQueryOps::e,
        MaxQueryOps::F,
        MaxQueryOps::mapping,
        MaxQueryOps::composition,
        MaxQueryOps::id
    > tree;
 */

// For queries a[i] <- a[i]*mult + add
struct RangeAffineOps {
    struct S {
        long long sum, sz;
    };
    struct F {
        long long a, b;
    };
    using Node = node_t<int, S, F>;

    static const int MOD = 998244353;
    static S op(const S& left, int key, const S& right) {
        return S {
            (left.sum + key + right.sum) % MOD,
            left.sz + 1 + right.sz,
        };
    }
    static pair<int, S> e() {
        return {0, {0, 0}};
    }
    static pair<int, S> mapping(const F& f, Node* node) {
        return {
            (f.a * node->key + f.b) % MOD,
            S {
                (f.a * node->data.sum + f.b * node->data.sz) % MOD,
                node->data.sz,
            }
        };
    }
    static F composition(const F&f, const F& g) {
        return F {
            f.a * g.a % MOD,
            (f.a * g.b + f.b) % MOD,
        };
    }
    static F id() {
        return F {1, 0};
    }
};

/* Example
    SplayTreeById<
        int,
        RangeAffineOps::S,
        RangeAffineOps::op,
        RangeAffineOps::e,
        RangeAffineOps::F,
        RangeAffineOps::mapping,
        RangeAffineOps::composition,
        RangeAffineOps::id
    > tree(keys);
 */
#line 8 "DataStructure/test/splay_tree.test.cpp"

using modular = ModInt<998244353>;

struct S {
    long long sum, sz;
};
struct F {
    long long a, b;
};
using Node = node_t<int, S, F>;

const int MOD = 998244353;
S op(S left, int key, S right) {
    return S {
        (left.sum + key + right.sum) % MOD,
        left.sz + 1 + right.sz,
    };
}
pair<int, S> e() {
    return {0, {0, 0}};
}
pair<int, S> mapping(F f, Node* node) {
    return {
        (f.a * node->key + f.b) % MOD,
        S {
            (f.a * node->data.sum + f.b * node->data.sz) % MOD,
            node->data.sz,
        }
    };
}
F composition(F f, F g) {
    return F {
        f.a * g.a % MOD,
        (f.a * g.b + f.b) % MOD,
    };
}
F id() {
    return F {1, 0};
}

int32_t main() {
    ios::sync_with_stdio(0); cin.tie(0);
    int n, q; cin >> n >> q;
    vector<int> keys(n);
    for (auto& key : keys) cin >> key;

    SplayTreeById<
        int,
        S,
        op,
        e,
        F,
        mapping,
        composition,
        id
    > tree(keys);

    while (q--) {
        int typ; cin >> typ;
        if (typ == 0) {
            int pos, val; cin >> pos >> val;
            tree.insert(pos, val);
        } else if (typ == 1) {
            int pos; cin >> pos;
            tree.erase(pos);
        } else if (typ == 2) {
            int l, r; cin >> l >> r;
            tree.reverse(l, r);
        } else if (typ == 3) {
            int l, r, a, b; cin >> l >> r >> a >> b;
            tree.apply(l, r, F{a, b});
        } else {
            assert(typ == 4);
            int l, r; cin >> l >> r;
            printf("%lld\n", tree.prod(l, r).sum);
        }
    }
    return 0;
}
Back to top page