Weight Balanced Leafy Tree

所谓 Leafy Tree,就是把序列信息都保存在叶子节点上的树;非叶子节点上只保存用于维护树形态的变量(比如说线段树就是一种 Leafy Tree). 由于每个非叶子节点都恰好有两个儿子,n 个叶子对应 n-1 个非叶子节点,所以我们将要学习的 Leafy Tree 占用的空间是 $2n-1 \approx 2n = O(n)$ 的.

以下我们使用指针写法.

节点维护的信息

我们定义节点如下:

1
2
3
4
5
6
7
8
9
// Sentinel value: a node constructed with val == inf is an internal (non-leaf) node.
static constexpr int inf = 1e9;
struct node {
    std::shared_ptr<node> l;  // left child (null for a leaf)
    std::shared_ptr<node> r;  // right child (null for a leaf)
    int size;                 // 1 for a freshly built leaf, 0 for a freshly built internal node
    int val;                  // leaf: the stored element; internal: inf until recomputed

    node(int v = inf) : l{}, r{}, size{v == inf ? 0 : 1}, val{v} {}
};
using ptr = std::shared_ptr<node>;

这里,size 和 val 比较重要:如果我们希望维护元素之间的相对大小(即当作平衡树使用),那么需要在 DFS 时利用 val 定位元素位置;如果我们希望拆散、重组序列(即不关心元素的大小顺序),那么就需要利用 size 对子树进行拆分.

  • tag 就是像线段树那样的懒标记(lazy tag). 上面的 node 中没有包含它;需要区间修改时,可以仿照线段树自行添加.
  • 只有叶子的 val 保存的是序列的元素;非叶子节点的 val 由两个子节点的信息合并而来(下面的 push_up 取的是右儿子的 val).

push_up()

就像线段树一样,用两个儿子的信息更新父节点:

1
2
3
4
5
6
void push_up(const ptr &p) {
p->size = 0;
if (p->l != nullptr) p->size += p->l->size;
if (p->r != nullptr) p->size += p->r->size;
p->val = p->r->val;
}

create()

也没什么好说的,就是统一了一下 API 而已.

1
2
// Node factory: with the default v == inf it yields an internal (non-leaf) node,
// otherwise a leaf storing v. Wraps std::make_shared so callers use one uniform API.
ptr create(int v = inf) {
    auto fresh = std::make_shared<node>(v);
    return fresh;
}

destroy()

销毁一个节点. 由于使用的是 std::shared_ptr,直接设为 nullptr 即可;当没有其他指针再引用该节点时,内存会被自动释放.

1
// Drop this handle's reference to the node; shared_ptr frees the node once no other handle owns it.
void destroy(ptr &z) { z.reset(); }

join(ptr, ptr)

新建一个非叶子节点,把两棵子树分别挂为它的左、右儿子,再 push_up 即可. 非常简单(trivial).

1
2
3
4
5
6
// Glue two subtrees under a brand-new internal node (l becomes the left child,
// r the right child) and refresh that node's size/val via push_up.
ptr join(const ptr &l, const ptr &r) {
    ptr parent = create();
    parent->l = l;
    parent->r = r;
    push_up(parent);
    return parent;
}

cut(ptr)

把根节点的两棵子树拆下来,然后删除根节点,并返回这两棵子树.

1
2
3
4
5
6
// Split a tree into 2: grab both subtrees of z, release z's handle to the root node,
// and return {left subtree, right subtree}. z must be non-null.
std::pair<ptr, ptr> cut(ptr &z) {
    std::pair<ptr, ptr> halves{z->l, z->r};
    destroy(z);
    return halves;
}


Template Code

Simple Code for Maintaining Sequence
指针风格(用 std::unique_ptr 动态分配节点:容量只受内存限制,但速度较慢)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
class WBLT {
    using T = int;

public:
    /// Tree node. A leaf (size == 1) stores one sequence element; an internal node always
    /// has two children and caches aggregates of its subtree (see pushUp).
    struct node {
        int size;                                 // number of leaves (elements) in this subtree
        T val, sum;                               // val: value of the rightmost leaf; sum: sum of all leaf values
        std::array<std::unique_ptr<node>, 2> ch;  // ch[0] = left child, ch[1] = right child; both null in a leaf

        node() : size{0}, val{0}, sum{0}, ch{nullptr, nullptr} {}
        node(T v) : size{1}, val{v}, sum{v}, ch{nullptr, nullptr} {}
    };
    using pointer = std::unique_ptr<node>;

public:
    /**
     * @brief Concatenate a and b (a's elements first) into a new WBLT.
     * @note The nodes of a and b are moved into the result; both are left empty.
     */
    friend WBLT merge(WBLT &a, WBLT &b) {
        WBLT res;
        res.root = a.merge(a.root, b.root);
        return res;
    }

    /**
     * @brief Append all elements of b after the elements of this tree; b is left empty.
     */
    void merge(WBLT &b) { root = merge(root, b.root); }

    /**
     * @brief Split the current WBLT into two WBLTs: the first k elements are returned,
     *        the remaining n-k elements stay in this tree.
     */
    WBLT split(int k) {
        WBLT res;
        std::tie(res.root, root) = split(root, k);
        return res;
    }

    /**
     * @brief Insert a value into the WBLT by its rank, and assume it's the k-th element. (1-indexed)
     * @param k   target position, 1 <= k <= n + 1
     * @param val value to insert
     */
    void insertByRank(int k, T val) {
        if (k == 1 || root == nullptr) {
            auto p = newLeaf(val);
            root = merge(p, root);
            return;
        }
        if (k == root->size + 1) {
            auto p = newLeaf(val);
            root = merge(root, p);
            return;
        }
        assert(k <= root->size + 1 && k >= 1);
        auto [left, right] = split(root, k - 1);
        auto p = newLeaf(val);
        auto tail = merge(p, right);
        root = merge(left, tail);
    }

    /**
     * @brief Get a value from the WBLT by its rank. (1-indexed)
     *
     * Read-only: walks one root-to-leaf path using subtree sizes instead of
     * splitting and re-merging the tree.
     */
    T getByRank(int kth) {
        assert(root != nullptr && kth <= root->size && kth >= 1);
        const node *cur = root.get();
        // Internal nodes always have two children (pushUp asserts this), so a null
        // left child identifies a leaf.
        while (cur->ch[0] != nullptr) {
            const node *lhs = cur->ch[0].get();
            if (kth <= lhs->size) {
                cur = lhs;
            } else {
                kth -= lhs->size;
                cur = cur->ch[1].get();
            }
        }
        return cur->val;
    }

    /**
     * @brief Build WBLT from vector[l..r] (inclusive) of a vector or array.
     *        An empty range (l > r) leaves the tree empty.
     */
    template <typename U>
    void build(const U &vector, int l, int r) {
        _build(vector, l, r, root);
    }

    /**
     * @brief Get the root of WBLT
     */
    pointer &getRoot() { return root; }

protected:
    pointer newNode() { return std::make_unique<node>(); }
    void deleteNode(pointer &p) { p = nullptr; }
    pointer newLeaf(T val) { return std::make_unique<node>(val); }

    /// Attach left and right (taking ownership of both) under a fresh internal node.
    pointer join(pointer &left, pointer &right) {
        pointer parent = newNode();
        parent->ch[0] = std::move(left);
        parent->ch[1] = std::move(right);
        pushUp(parent);
        return parent;
    }

    /// Delete the parent node p and hand back ownership of its two children
    /// ({nullptr, nullptr} if p is null).
    std::pair<pointer, pointer> cut(pointer &p) {
        if (p == nullptr)
            return {nullptr, nullptr};
        pointer left = std::move(p->ch[0]), right = std::move(p->ch[1]);
        deleteNode(p);
        return {std::move(left), std::move(right)};
    }

    /// Single rotation of rt: r != 0 lifts the right child, r == 0 lifts the left child.
    void rotate(pointer &rt, int r) {
        auto [a, b] = cut(rt);
        if (r) {
            auto [c, d] = cut(b);
            rt = join(b = join(a, c), d);
        } else {
            auto [c, d] = cut(a);
            rt = join(c, a = join(d, b));
        }
    }

    /// Recompute size, val (rightmost leaf value) and sum of an internal node from its children.
    void pushUp(pointer &p) {
        assert(p->ch[0] && p->ch[1]);
        p->size = p->ch[0]->size + p->ch[1]->size;
        p->val = p->ch[1]->val;
        p->sum = p->ch[0]->sum + p->ch[1]->sum;
    }

    // functions to help WBLT keep balanced
    /// True when a subtree of size sx outweighs one of size sy by more than 3:1.
    bool heavy(int sx, int sy) { return sx > sy * 3; }
    bool doubleRotationRequired(pointer &rt, int r) { return rt->ch[!r]->size > rt->ch[r]->size * 2; }
    /// Re-merge rt's children if one of them is heavy relative to the other.
    void balance(pointer &rt) {
        if (rt->size == 1)
            return;
        if (heavy(rt->ch[0]->size, rt->ch[1]->size) || heavy(rt->ch[1]->size, rt->ch[0]->size)) {
            auto [a, b] = cut(rt);
            rt = merge(a, b);
        }
    }

    // functions to build WBLT
    /// Build a perfectly balanced subtree from vector[l..r] into rt (rt is cleared for l > r).
    template <typename U>
    void _build(const U &vector, int l, int r, pointer &rt) {
        if (l > r) {  // empty range: previously recursed forever
            rt = nullptr;
            return;
        }
        if (l == r) {
            rt = newLeaf(vector[l]);
            return;
        }
        if (rt == nullptr)
            rt = newNode();
        int mid = l + ((r - l) >> 1);
        _build(vector, l, mid, rt->ch[0]);
        _build(vector, mid + 1, r, rt->ch[1]);
        pushUp(rt);
    }

    /**
     * @brief Split the tree rooted at rt into {first k elements, remaining elements}.
     * @warning rt is consumed (left null); use the returned pair.
     */
    std::pair<pointer, pointer> split(pointer &rt, int k) {
        if (rt == nullptr)
            return {nullptr, nullptr};
        if (k <= 0)
            return {nullptr, std::move(rt)};
        if (k >= rt->size)
            return {std::move(rt), nullptr};

        auto [left, right] = cut(rt);
        if (k <= left->size) {
            auto [l, r] = split(left, k);
            return {std::move(l), merge(r, right)};
        } else {
            auto [l, r] = split(right, k - left->size);
            return {merge(left, l), std::move(r)};
        }
    }

    /**
     * @brief Merge two WBLTs into one WBLT (without re-ordering); both arguments are consumed.
     *
     * If one side outweighs the other by more than 3:1, the heavy side is cut open and its
     * pieces are re-merged (a single rotation, or a double one when the middle part would
     * become too heavy); otherwise the two roots are simply joined.
     */
    pointer merge(pointer &left, pointer &right) {
        if (left == nullptr || right == nullptr)
            return left == nullptr ? std::move(right) : std::move(left);

        if (heavy(left->size, right->size)) {
            auto [a, b] = cut(left);
            if (heavy(b->size + right->size, a->size)) {
                auto [c, d] = cut(b);
                return merge(left = merge(a, c), b = merge(d, right));
            }
            return merge(a, left = merge(b, right));
        }
        if (heavy(right->size, left->size)) {
            auto [a, b] = cut(right);
            if (heavy(a->size + left->size, b->size)) {
                auto [c, d] = cut(a);
                return merge(right = merge(left, c), a = merge(d, b));
            }
            return merge(right = merge(left, a), b);
        }
        // Returning the prvalue directly allows copy elision (std::move here was a pessimization).
        return join(left, right);
    }

private:
    pointer root{nullptr};
};
数组风格(所有节点存放在一个预先分配的节点池中,速度更快. 默认使用 std::vector 预分配 CacheSize 个节点;若改用注释中的 std::array<> 版本,容量在编译期固定,受内存制约,容量较小)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
class WBLT {
    using T = int;
    using index = int;
    // Initial number of pool slots (10e5 == 1'000'000, plus slack). Slot 0 is the null index.
    static constexpr int CacheSize = 10e5 + 10;

public:
    /// Pool node. A leaf (size == 1) stores one element; an internal node has two children
    /// (pool indices, 0 meaning "none") and caches aggregates of its subtree (see pushUp).
    struct node {
        int size;                // number of leaves (elements) in this subtree
        T val, sum;              // val: value of the rightmost leaf; sum: sum of all leaf values
        std::array<index, 2> ch; // ch[0] = left child, ch[1] = right child; both 0 in a leaf

        node() : size{0}, val{0}, sum{0}, ch{0, 0} {}
        node(T v) : size{1}, val{v}, sum{v}, ch{0, 0} {}
    };

public:
    /// @brief Access the specific element in the tree
    /// @warning The returned reference may be invalidated by any operation that allocates
    ///          nodes, because the pool grows on demand.
    node &operator[](index i) { return pool[i]; }

    /// @brief Split the current WBLT into two WBLTs, with size k (returned) and n-k (current)
    index split(int k) {
        index res = 0;
        std::tie(res, root) = split(root, k);
        return res;
    }

    /// @brief Concatenate the trees rooted at x and y (x's elements first); returns the new root.
    index merge(index x, index y) {
        if (x == 0 || y == 0)
            return x == 0 ? y : x;

        if (heavy(pool[x].size, pool[y].size)) {
            auto [a, b] = cut(x);
            if (heavy(pool[b].size + pool[y].size, pool[a].size)) {
                auto [c, d] = cut(b);
                return merge(merge(a, c), merge(d, y));
            }
            return merge(a, merge(b, y));
        }
        if (heavy(pool[y].size, pool[x].size)) {
            auto [a, b] = cut(y);
            if (heavy(pool[a].size + pool[x].size, pool[b].size)) {
                auto [c, d] = cut(a);
                return merge(merge(x, c), merge(d, b));
            }
            return merge(merge(x, a), b);
        }
        return join(x, y);
    }

    /// @brief Append the tree rooted at b after the current tree.
    void merge(int b) { root = merge(root, b); }

    index getRoot() const { return root; }

    /// @brief Insert val so that it becomes the k-th element (1-indexed, 1 <= k <= n + 1).
    void insertByRank(int k, T val) {
        if (k == 1 || root == 0) {
            auto p = newLeaf(val);
            root = merge(p, root);
            return;
        }
        if (k == pool[root].size + 1) {
            auto p = newLeaf(val);
            root = merge(root, p);
            return;
        }
        assert(k <= pool[root].size + 1 && k >= 1);
        auto [left, right] = split(root, k - 1);
        auto p = newLeaf(val);
        root = merge(left, merge(p, right));
    }

    /// @brief Get the k-th element (1-indexed).
    /// Read-only: walks one root-to-leaf path by subtree sizes instead of splitting and re-merging.
    T getByRank(int kth) {
        assert(root != 0 && kth <= pool[root].size && kth >= 1);
        index cur = root;
        // Internal nodes always have two children (pushUp asserts this), so ch[0] == 0 marks a leaf.
        while (pool[cur].ch[0] != 0) {
            index lhs = pool[cur].ch[0];
            if (kth <= pool[lhs].size) {
                cur = lhs;
            } else {
                kth -= pool[lhs].size;
                cur = pool[cur].ch[1];
            }
        }
        return pool[cur].val;
    }

    /// @brief Build the tree from vector[l..r] (inclusive); an empty range (l > r) gives an empty tree.
    /// The container is taken by const reference to avoid copying it.
    template <typename U>
    void build(const U &vector, int l, int r) {
        root = _build(vector, l, r);
    }

protected:
    /// Hand out a free slot: reuse a recycled one if available, otherwise the next fresh one.
    index newNode() {
        if (!trash.empty()) {
            index res = trash.back();
            trash.pop_back();
            return res;
        }
        // Grow the pool instead of writing past its end once every slot has been used.
        if (poolid + 1 >= static_cast<index>(pool.size()))
            pool.resize(pool.size() * 2 + 1);
        return ++poolid;
    }
    /// Reset slot p, put it on the free list, and null out the caller's index.
    void deleteNode(index &p) { trash.push_back(p), pool[p] = node(), p = 0; }
    index newLeaf(T val) {
        index p = newNode();
        pool[p] = node(val);
        return p;
    }

    /// @brief Attach two nodes to a new root
    index join(index left, index right) {
        index nroot = newNode();
        pool[nroot].ch[0] = left;
        pool[nroot].ch[1] = right;
        pushUp(nroot);
        return nroot;
    }

    /// @brief Remove the parent node and return its children
    std::pair<index, index> cut(index &p) {
        auto [y, z] = pool[p].ch;
        deleteNode(p);
        return {y, z};
    }

    /// @brief Single rotation of rt: r != 0 lifts the right child, r == 0 lifts the left child.
    void rotate(index &rt, int r) {
        auto [a, b] = cut(rt);
        if (r) {
            auto [c, d] = cut(b);
            rt = join(join(a, c), d);
        } else {
            auto [c, d] = cut(a);
            rt = join(c, join(d, b));
        }
    }
    /// True when a subtree of size sx outweighs one of size sy by more than 3:1.
    bool heavy(int sx, int sy) { return sx > sy * 3; }
    bool doubleRotationRequired(index &rt, int r) {
        return pool[rt].ch[!r] && pool[pool[rt].ch[!r]].size > pool[pool[rt].ch[r]].size * 2;
    }
    /// Re-merge rt's children if one of them is heavy relative to the other.
    void balance(index &rt) {
        if (pool[rt].size == 1)
            return;
        if (heavy(pool[pool[rt].ch[0]].size, pool[pool[rt].ch[1]].size)
            || heavy(pool[pool[rt].ch[1]].size, pool[pool[rt].ch[0]].size)) {
            auto [a, b] = cut(rt);
            rt = merge(a, b);
        }
    }

    /// @brief Maintain the supporting info in WBLT non-leaf node.
    void pushUp(index p) {
        assert(pool[p].ch[0] && pool[p].ch[1]);
        index ls = pool[p].ch[0], rs = pool[p].ch[1];
        pool[p].size = pool[ls].size + pool[rs].size;
        pool[p].val = pool[rs].val;
        pool[p].sum = pool[ls].sum + pool[rs].sum;
    }

    /// Split the tree rooted at x into {first k elements, remaining elements}; x's root slot is recycled.
    std::pair<index, index> split(int x, int k) {
        if (x == 0)
            return {0, 0};
        if (k <= 0)
            return {0, x};
        if (k >= pool[x].size)
            return {x, 0};

        auto [left, right] = cut(x);
        if (k <= pool[left].size) {
            auto [l, r] = split(left, k);
            return {l, merge(r, right)};
        } else {
            auto [l, r] = split(right, k - pool[left].size);
            return {merge(left, l), r};
        }
    }

    // helper function to build WBLT: returns the root of a balanced tree over vector[l..r] (0 if l > r)
    template <typename U>
    index _build(const U &vector, int l, int r) {
        if (l > r)  // empty range: previously recursed forever
            return 0;
        if (l == r)
            return newLeaf(vector[l]);
        int mid = l + ((r - l) >> 1);
        index lhs = _build(vector, l, mid);
        index rhs = _build(vector, mid + 1, r);
        return join(lhs, rhs);
    }

private:
    index root{0}, poolid{0};
    // std::array<node, CacheSize> pool;  // fixed-capacity alternative (would not grow)
    std::vector<node> pool = std::vector<node>(CacheSize);
    std::vector<int> trash;  // recycled slot indices
};