Treap

比较基础的平衡树。

Treap

1. 定义

Tree + Heap = Treap

前置知识:BST + Heap

BST 的链接

2. 原理

由于 BST 的复杂度与高度相关,而 BST 容易退化。(如插入一条链)

而可以发现,随机数据下,期望高度为 $O(\log n)$。

Treap 恰好利用了这一点。

对于每一个节点,都额外赋予一个 $val$,并实时维护一个堆(小根或大根均可)。

可以发现,只要 $val$ 确定,树的形态就唯一确定($val$ 都不同)

3. 基本操作

1)右旋 (zig)

1 和 2 都是交换儿子与父亲的操作,前提是不会影响中序遍历。

是将左儿子换到父亲的位置。

先将左儿子换到父亲,将父亲换到右儿子。

将左儿子的左子树换到父亲的左子树,左儿子的右子树换到右儿子的左子树。

画个图理解一下。

看代码

1
2
3
4
5
6
void zig(int &p)
{
int q=tr[p].l;
tr[p].l=tr[q].r,tr[q].r=p,p=q;
pushup(tr[p].r);pushup(p);
}

2)左旋 (zag)

是将右儿子换到父亲为位置。

先将右儿子换到父亲,将父亲换到左儿子。

将右儿子的左子树换到左儿子的右子树,父亲的左子树换到左儿子的左子树。

可以发现,左旋与右旋是互逆操作。

画个图理解下。

看代码。

1
2
3
4
5
6
void zag(int &p)
{
int q=tr[p].r;
tr[p].r=tr[q].l,tr[q].l=p,p=q;
pushup(tr[p].l);pushup(p);
}

3)插入

首先回溯插入(见 BST),在回溯过程中如果儿子与父亲不满足堆性质,则交换。

4)删除

可以发现,左旋和右旋都可以使一个节点高度降低。

所以将该节点旋转至叶节点,直接删除即可。

请注意,在维护过程中同时注意堆性质,防止旋反。

如果是大根堆,就应该将大的 $val$ 旋到父节点。

3.例题

T1:普通平衡树

题目传送门 Luogu

题目传送门 AcWing

模板题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<ctime>
using namespace std;

const int N=1e5+10,INF=1e9;

struct Node{
int l,r;
int key,val,s,cnt;
}tr[N];

int n,idx,rt;

void pushup(int p)
{
tr[p].s=tr[tr[p].l].s+tr[tr[p].r].s+tr[p].cnt;
}

void zig(int &p)
{
int q=tr[p].l;
tr[p].l=tr[q].r,tr[q].r=p,p=q;
pushup(tr[p].r);pushup(p);
}

void zag(int &p)
{
int q=tr[p].r;
tr[p].r=tr[q].l,tr[q].l=p,p=q;
pushup(tr[p].l);pushup(p);
}

int get_node(int key)
{
tr[++idx].key=key;
tr[idx].val=rand();
tr[idx].s=tr[idx].cnt=1;
return idx;
}

void build()
{
get_node(-INF),get_node(INF);
rt=1,tr[1].r=2;
}

void insert(int &p,int key)
{
if (!p)
{
p=get_node(key);
return;
}
if (tr[p].key==key) tr[p].cnt++;
else if (tr[p].key>key)
{
insert(tr[p].l,key);
if (tr[p].val<tr[tr[p].l].val) zig(p);
}
else
{
insert(tr[p].r,key);
if (tr[tr[p].r].val>tr[p].val) zag(p);
}
pushup(p);
}

void remove(int &p,int key)
{
if (!p) return;
if (tr[p].key==key)
{
if (tr[p].cnt>1) tr[p].cnt--;
else if (tr[p].l||tr[p].r)
{
if (!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val) zig(p),remove(tr[p].r,key);
else zag(p),remove(tr[p].l,key);
}
else p=0;
}
else if (tr[p].key<key) remove(tr[p].r,key);
else remove(tr[p].l,key);
pushup(p);
}

int get_rank_by_key(int p,int key)
{
if (!p) return 0;
if (tr[p].key==key) return tr[tr[p].l].s+1;
if (tr[p].key>key) return get_rank_by_key(tr[p].l,key);
return tr[tr[p].l].s+tr[p].cnt+get_rank_by_key(tr[p].r,key);
}

int get_key_by_rank(int p,int rank)
{
if (!p) return INF;
if (tr[tr[p].l].s>=rank) return get_key_by_rank(tr[p].l,rank);
if (tr[tr[p].l].s+tr[p].cnt>=rank) return tr[p].key;
return get_key_by_rank(tr[p].r,rank-tr[tr[p].l].s-tr[p].cnt);
}

int get_prev(int p,int key)
{
if (!p) return -INF;
if (tr[p].key>=key) return get_prev(tr[p].l,key);
return max(tr[p].key,get_prev(tr[p].r,key));
}

int get_next(int p,int key)
{
if (!p) return INF;
if (tr[p].key<=key) return get_next(tr[p].r,key);
return min(tr[p].key,get_next(tr[p].l,key));
}

int main()
{
srand(time(NULL));
build();
cin>>n;
int op,x;
while (n--)
{
scanf("%d %d",&op,&x);
switch (op)
{
case 1:insert(rt,x);break;
case 2:remove(rt,x);break;
case 3:printf("%d\n",get_rank_by_key(rt,x)-1);break;
case 4:printf("%d\n",get_key_by_rank(rt,x+1));break;
case 5:printf("%d\n",get_prev(rt,x));break;
case 6:printf("%d\n",get_next(rt,x));break;
}
}
return 0;
}

T2:[HNOI 2002]营业额统计

题目传送门 Luogu

题目传送门 AcWing

找$min(|a[i]-a[j]|)(1<=j<i)$,直接维护在当前数之前的序列。

直接维护 Treap 即可。

但是我用的是 set 乱搞。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <set>
using namespace std;

typedef long long ll;
const int INF = 0x3f3f3f3f;
set<int> ta;
int n;
ll ans;

int main()
{
cin >> n >> ans;
ta.insert(-INF), ta.insert(ans), ta.insert(INF);
for (int i = 2, x; i <= n; ++ i)
{
scanf("%d", &x);
if (ta.find(x) != ta.end()) continue;
auto iter = ta.insert(x).first;
auto ne = iter, pre = iter;
ne ++, pre --;
ans += min(*ne - *iter, *iter - *pre);
}
cout << ans << endl;
return 0;
}