Weight Balanced Leafy Tree

Template Code

Simple Code for Maintaining a Sequence
指针风格 / Pointer style (nodes are heap-allocated through `std::unique_ptr`: large capacity, but slower)
class WBLT {
  using T = int;

public:
  /// Tree node. A leaf has size 1 and no children; an internal node always owns two
  /// children and caches aggregates of its subtree (maintained by pushUp).
  struct node {
    int size;                                 // number of leaves in the subtree
    T val, sum;                               // leaf: the element; internal: right child's val / sum of leaves
    std::array<std::unique_ptr<node>, 2> ch;  // ch[0] = left child, ch[1] = right child

    node() : size{0}, val{0}, sum{0}, ch{nullptr, nullptr} {}
    node(T v) : size{1}, val{v}, sum{v}, ch{nullptr, nullptr} {}
  };
  using pointer = std::unique_ptr<node>;

public:
  /**
   * @brief Concatenate two WBLTs: the elements of `a` followed by the elements of `b`.
   * @note The nodes of both `a` and `b` are moved into the returned tree.
   */
  friend WBLT merge(WBLT &a, WBLT &b) {
    WBLT res;
    res.root = a.merge(a.root, b.root);
    return res;
  }

  /// @brief Append the elements of `b` after this tree's elements (`b`'s nodes are moved in).
  void merge(WBLT &b) { root = merge(root, b.root); }

  /**
   * @brief Split the current WBLT into two WBLTs: the first k elements are returned,
   *        the remaining n-k elements stay in this tree.
   */
  WBLT split(int k) {
    WBLT res;
    std::tie(res.root, root) = split(root, k);
    return res;
  }

  /**
   * @brief Insert `val` so that it becomes the k-th element (1-indexed, 1 <= k <= n+1).
   */
  void insertByRank(int k, T val) {
    if (k == 1 || root == nullptr) {
      auto p = newLeaf(val);
      root = merge(p, root);
      return;
    } else if (k == root->size + 1) {
      auto p = newLeaf(val);
      root = merge(root, p);
      return;
    }
    assert(k <= root->size + 1 && k >= 1);
    auto [left, right] = split(root, k - 1);
    auto p = newLeaf(val);
    root = merge(left, p = merge(p, right));
  }

  /**
   * @brief Get the k-th element (1-indexed).
   *
   * Walks one root-to-leaf path using subtree sizes; the tree is not modified and no
   * nodes are allocated (the previous split/merge round trip did both).
   */
  T getByRank(int kth) {
    assert(root != nullptr && kth <= root->size && kth >= 1);
    const node *cur = root.get();
    while (cur->ch[0] != nullptr) {  // internal nodes always have both children
      const node *left = cur->ch[0].get();
      if (kth <= left->size) {
        cur = left;
      } else {
        kth -= left->size;
        cur = cur->ch[1].get();
      }
    }
    return cur->val;
  }

  /**
   * @brief Build the tree from vector[l..r] (inclusive), replacing the current contents.
   *        An empty range (l > r) leaves the tree empty.
   */
  template <typename U>
  void build(const U &vector, int l, int r) {
    _build(vector, l, r, root);
  }

  /**
   * @brief Get the root of WBLT (nullptr when the tree is empty).
   */
  pointer &getRoot() { return root; }

protected:
  pointer newNode() { return std::make_unique<node>(); }
  void deleteNode(pointer &p) { p = nullptr; }
  pointer newLeaf(T val) { return std::make_unique<node>(val); }

  /// @brief Create a new internal node owning `left` and `right` (both are moved from).
  pointer join(pointer &left, pointer &right) {
    pointer parent = newNode();
    parent->ch[0] = std::move(left);
    parent->ch[1] = std::move(right);
    pushUp(parent);
    return parent;
  }

  /// @brief Free node `p` and hand back its two children (both nullptr for a leaf or a null `p`).
  std::pair<pointer, pointer> cut(pointer &p) {
    if (p == nullptr)
      return {nullptr, nullptr};
    pointer left = std::move(p->ch[0]), right = std::move(p->ch[1]);
    deleteNode(p);
    return {std::move(left), std::move(right)};
  }

  /// @brief Single rotation at `rt`: r = 1 lifts the right child to the top, r = 0 lifts the left child.
  void rotate(pointer &rt, int r) {
    auto [a, b] = cut(rt);
    if (r) {
      auto [c, d] = cut(b);
      rt = join(b = join(a, c), d);
    } else {
      auto [c, d] = cut(a);
      rt = join(c, a = join(d, b));
    }
  }

  /// @brief Recompute size, val (right child's val) and sum of internal node `p`.
  void pushUp(pointer &p) {
    assert(p->ch[0] && p->ch[1]);
    p->size = p->ch[0]->size + p->ch[1]->size;
    p->val = p->ch[1]->val;
    p->sum = p->ch[0]->sum + p->ch[1]->sum;
  }

  // functions to help WBLT keep balanced
  /// A subtree of size sx is "heavy" next to one of size sy when it is more than 3x larger.
  bool heavy(int sx, int sy) { return sx > sy * 3; }
  /// Null-safe: a leaf (no children) never requires a double rotation.
  bool doubleRotationRequired(pointer &rt, int r) {
    return rt->ch[!r] && rt->ch[!r]->size > rt->ch[r]->size * 2;
  }
  /// @brief Rebuild `rt` from its children through merge() if one child is heavy.
  void balance(pointer &rt) {
    if (rt->size == 1)
      return;
    if (heavy(rt->ch[0]->size, rt->ch[1]->size) || heavy(rt->ch[1]->size, rt->ch[0]->size)) {
      auto [a, b] = cut(rt);
      rt = merge(a, b);
    }
  }

  // functions to build WBLT
  /// @brief Build a perfectly balanced tree over vector[l..r] into `rt` (nullptr when l > r).
  template <typename U>
  void _build(const U &vector, int l, int r, pointer &rt) {
    if (l > r) {  // empty range: without this guard `mid` never moves and recursion never ends
      rt = nullptr;
      return;
    }
    if (l == r) {
      rt = newLeaf(vector[l]);
      return;
    }
    int mid = l + ((r - l) >> 1);
    pointer left, right;
    _build(vector, l, mid, left);
    _build(vector, mid + 1, r, right);
    rt = join(left, right);
  }

  /**
   * @brief Split the WBLT into two WBLTs, with size k and n-k.
   * @warning `rt` is consumed: it is left empty and its nodes move into the returned pair.
   */
  std::pair<pointer, pointer> split(pointer &rt, int k) {
    if (rt == nullptr)
      return {nullptr, nullptr};
    if (k <= 0)
      return {nullptr, std::move(rt)};
    if (k >= rt->size)
      return {std::move(rt), nullptr};

    auto [left, right] = cut(rt);
    if (k <= left->size) {
      auto [l, r] = split(left, k);
      return {std::move(l), merge(r, right)};
    } else {
      auto [l, r] = split(right, k - left->size);
      return {merge(left, l), std::move(r)};
    }
  }

  /**
   * @brief Merge two WBLTs into one WBLT (without re-ordering): `left`'s elements first.
   *        Both arguments are consumed.
   */
  pointer merge(pointer &left, pointer &right) {
    if (left == nullptr || right == nullptr)
      return left == nullptr ? std::move(right) : std::move(left);

    if (heavy(left->size, right->size)) {
      auto [a, b] = cut(left);
      if (heavy(b->size + right->size, a->size)) {
        auto [c, d] = cut(b);
        return merge(left = merge(a, c), b = merge(d, right));
      } else
        return merge(a, left = merge(b, right));
    } else if (heavy(right->size, left->size)) {
      auto [a, b] = cut(right);
      if (heavy(a->size + left->size, b->size)) {
        auto [c, d] = cut(a);
        return merge(right = merge(left, c), a = merge(d, b));
      } else
        return merge(right = merge(left, a), b);
    } else
      return join(left, right);  // returned prvalue: no std::move (it would block elision)
  }

private:
  pointer root{nullptr};
};
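数组风格 / Array style (faster: nodes are addressed by integer index in a pool preallocated with `CacheSize` slots; the pool is a `std::vector`, because a `std::array<>` of that size is constrained by memory)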
class WBLT {
  using T = int;
  using index = int;
  /// Pool slots allocated up front (same value as the former `10e5 + 10`); slot 0 is the
  /// null sentinel. The pool grows on demand beyond this.
  static constexpr int CacheSize = 1'000'010;

public:
  /// Pool node. Index 0 means "no node". A leaf has size 1 and children {0, 0}; an internal
  /// node always has two children and caches aggregates of its subtree (see pushUp).
  struct node {
    int size;                 // number of leaves in the subtree
    T val, sum;               // leaf: the element; internal: right child's val / sum of leaves
    std::array<index, 2> ch;  // ch[0] = left child, ch[1] = right child

    node() : size{0}, val{0}, sum{0}, ch{0, 0} {}
    node(T v) : size{1}, val{v}, sum{v}, ch{0, 0} {}
  };

public:
  /// @brief Access the node stored at pool index `i`.
  /// @note The reference stays valid until the pool has to grow past its current capacity.
  node &operator[](index i) { return pool[i]; }

  /// @brief Split the current WBLT into two WBLTs: the root index of the first k elements is
  ///        returned, the remaining n-k elements stay in this tree.
  index split(int k) {
    index res = 0;
    std::tie(res, root) = split(root, k);
    return res;
  }

  /// @brief Concatenate the trees rooted at `x` and `y` (x's elements first) and return the
  ///        new root. Either index may be 0 (an empty tree).
  index merge(index x, index y) {
    if (x == 0 || y == 0)
      return x == 0 ? y : x;

    if (heavy(pool[x].size, pool[y].size)) {
      auto [a, b] = cut(x);
      if (heavy(pool[b].size + pool[y].size, pool[a].size)) {
        auto [c, d] = cut(b);
        return merge(merge(a, c), merge(d, y));
      } else
        return merge(a, merge(b, y));
    } else if (heavy(pool[y].size, pool[x].size)) {
      auto [a, b] = cut(y);
      if (heavy(pool[a].size + pool[x].size, pool[b].size)) {
        auto [c, d] = cut(a);
        return merge(merge(x, c), merge(d, b));
      } else
        return merge(merge(x, a), b);
    } else
      return join(x, y);
  }

  /// @brief Append the tree rooted at `b` after this tree's elements.
  void merge(int b) { root = merge(root, b); }

  /// @brief Root index of the tree (0 when empty).
  index getRoot() const { return root; }

  /// @brief Insert `val` so that it becomes the k-th element (1-indexed, 1 <= k <= n+1).
  void insertByRank(int k, T val) {
    if (k == 1 || root == 0) {
      auto p = newLeaf(val);
      root = merge(p, root);
      return;
    } else if (k == pool[root].size + 1) {
      auto p = newLeaf(val);
      root = merge(root, p);
      return;
    }
    assert(k <= pool[root].size + 1 && k >= 1);
    auto [left, right] = split(root, k - 1);
    auto p = newLeaf(val);
    root = merge(left, merge(p, right));
  }

  /// @brief Get the k-th element (1-indexed) by descending one root-to-leaf path.
  ///        The tree is not modified and no pool slots are touched.
  T getByRank(int kth) {
    assert(kth <= pool[root].size && kth >= 1);
    index cur = root;
    while (pool[cur].ch[0] != 0) {  // internal nodes always have both children
      index ls = pool[cur].ch[0];
      if (kth <= pool[ls].size) {
        cur = ls;
      } else {
        kth -= pool[ls].size;
        cur = pool[cur].ch[1];
      }
    }
    return pool[cur].val;
  }

  /// @brief Build the tree from vector[l..r] (inclusive), replacing the current contents.
  ///        The input is taken by const reference (no copy); the old tree's nodes are
  ///        returned to the free list. An empty range (l > r) leaves the tree empty.
  template <typename U>
  void build(const U &vector, int l, int r) {
    recycle(root);
    root = _build(vector, l, r);
  }

protected:
  /// @brief Take a slot from the free list, or a fresh one (growing the pool if needed).
  index newNode() {
    if (!trash.empty()) {
      index res = trash.back();
      trash.pop_back();
      return res;
    }
    index id = ++poolid;
    // Grow instead of writing past the end once all preallocated slots are in use.
    while (static_cast<index>(pool.size()) <= id)
      pool.emplace_back();
    return id;
  }
  /// @brief Return slot `p` to the free list, reset it, and set `p` to 0.
  void deleteNode(index &p) {
    assert(p != 0);  // slot 0 is the null sentinel and must never be handed out again
    trash.push_back(p);
    pool[p] = node();
    p = 0;
  }
  index newLeaf(T val) {
    index p = newNode();
    pool[p] = node(val);
    return p;
  }

  /// @brief Attach two nodes to a new root
  index join(index left, index right) {
    index nroot = newNode();
    pool[nroot].ch[0] = left, pool[nroot].ch[1] = right;
    pushUp(nroot);
    return nroot;
  }

  /// @brief Remove the parent node and return its children ({0, 0} for a leaf or for p == 0).
  std::pair<index, index> cut(index &p) {
    if (p == 0)  // never free the null sentinel
      return {0, 0};
    auto [y, z] = pool[p].ch;
    deleteNode(p);
    return {y, z};
  }

  /// @brief Return every node of the subtree rooted at `p` to the free list.
  void recycle(index p) {
    std::vector<index> pending;
    if (p != 0)
      pending.push_back(p);
    while (!pending.empty()) {
      index q = pending.back();
      pending.pop_back();
      auto [a, b] = cut(q);  // frees q and yields its children
      if (a != 0)
        pending.push_back(a);
      if (b != 0)
        pending.push_back(b);
    }
  }

  /// @brief Single rotation at `rt`: r = 1 lifts the right child to the top, r = 0 lifts the left child.
  void rotate(index &rt, int r) {
    auto [a, b] = cut(rt);
    if (r) {
      auto [c, d] = cut(b);
      rt = join(join(a, c), d);
    } else {
      auto [c, d] = cut(a);
      rt = join(c, join(d, b));
    }
  }
  /// A subtree of size sx is "heavy" next to one of size sy when it is more than 3x larger.
  bool heavy(int sx, int sy) { return sx > sy * 3; }
  bool doubleRotationRequired(index &rt, int r) {
    return pool[rt].ch[!r] && pool[pool[rt].ch[!r]].size > pool[pool[rt].ch[r]].size * 2;
  }
  /// @brief Rebuild `rt` from its children through merge() if one child is heavy.
  void balance(index &rt) {
    if (pool[rt].size == 1)
      return;
    if (heavy(pool[pool[rt].ch[0]].size, pool[pool[rt].ch[1]].size)
        || heavy(pool[pool[rt].ch[1]].size, pool[pool[rt].ch[0]].size)) {
      auto [a, b] = cut(rt);
      rt = merge(a, b);
    }
  }

  /// @brief Maintain the supporting info in WBLT non-leaf node.
  void pushUp(index p) {
    assert(pool[p].ch[0] && pool[p].ch[1]);
    index ls = pool[p].ch[0], rs = pool[p].ch[1];
    pool[p].size = pool[ls].size + pool[rs].size, pool[p].val = pool[rs].val;
    pool[p].sum = pool[ls].sum + pool[rs].sum;
  }

  /// @brief Split the tree rooted at `x` into (first k elements, remaining elements).
  std::pair<index, index> split(int x, int k) {
    if (x == 0)
      return {0, 0};
    if (k <= 0)
      return {0, x};
    if (k >= pool[x].size)
      return {x, 0};

    auto [left, right] = cut(x);
    if (k <= pool[left].size) {
      auto [l, r] = split(left, k);
      return {l, merge(r, right)};
    } else {
      auto [l, r] = split(right, k - pool[left].size);
      return {merge(left, l), r};
    }
  }

  // helper function to build WBLT: balanced tree over vector[l..r], 0 for an empty range
  template <typename U>
  index _build(const U &vector, int l, int r) {
    if (l > r)  // without this guard `mid` never moves and recursion never ends
      return 0;
    if (l == r)
      return newLeaf(vector[l]);
    int mid = l + ((r - l) >> 1);
    return join(_build(vector, l, mid), _build(vector, mid + 1, r));
  }

private:
  index root{0}, poolid{0};
  // std::array<node, CacheSize> pool;
  std::vector<node> pool = std::vector<node>(CacheSize);
  std::vector<int> trash;
};