Splay

比较重要的基础数据结构。

Splay

前置知识:平衡二叉树,Treap / 左旋右旋操作。

1基本原理

每操作一次,均将该节点旋至树根。

2.核心操作

即 Splay 操作.。

定义 $splay(x,k)$ 为将点 x 旋转至 k 的下面。

特别的,$splay(x,0)$ 定义为将 x 旋转至根。

有四种情况,两种分类。

注意,转 x 的意思就是将 x 转到父节点。

假设 z 是 y 的父亲,y 是 x 的父亲。

第一种:成一条链的形状。

先转 y,再转 x。

第二种:成折线的形状。

转两次 x。

一般来说,k 的取值一般只有是 0 或根。

3.支持操作

1)插入

第一种,就是将 x 直接插入。

第二种,将一个序列插到 y 的后面。

主要讲第二种。

首先,将 y 的后继 z。

然后,执行 $splay(y,0)$。

接着,执行 $splay(z,y)$。

现在,由于 z 是 y 的后继,所以 z 一定没有左子树(否则就不是后继)。

直接先将插入序列构造成二叉树,接在 z 左子树即可。

2)删除

直接讲删除 $[L,R]$。

先将执行 $splay(L-1,0)$,然后 $splay(R+1,L-1)$。

此时,$[L,R]$,一定全部在 R+1 的左子树,直接断开来连接即可。

4.例题

T1:Splay/文艺平衡树

题目传送门 AcWing

题目传送门 Luogu

要维护以下信息:

  1. $size$
  2. $flag/lazytag$:表示是否会翻转。

需要维护。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;

const int N=1e5+10;

int n,rt,idx;
struct Node{
int s[2],size,p,v;
int flag;
void init(int _v,int _p)
{
v=_v;p=_p;
}
}t[N];

void pushup(int p)
{
t[p].size=t[t[p].s[0]].size+t[t[p].s[1]].size+1;
}

void pushdown(int p)
{
if (!t[p].flag) return;
swap(t[p].s[0],t[p].s[1]);
t[t[p].s[0]].flag^=1;
t[t[p].s[1]].flag^=1;
t[p].flag=0;
pushup(p);
}

void rotate(int x)
{
// cout<<x<<endl;
int y=t[x].p,z=t[y].p,k=(t[y].s[1]==x);
t[z].s[y==t[z].s[1]]=x,t[x].p=z;
t[y].s[k]=t[x].s[k^1],t[t[x].s[k^1]].p=y;
t[x].s[k^1]=y,t[y].p=x;
pushup(y);pushup(x);
}

void splay(int x,int k)
{
// cout<<x<<' '<<k<<endl;
while (t[x].p!=k)
{
// cout<<x<<endl;
int y=t[x].p,z=t[y].p;
if (z!=k)
if ((x==t[y].s[1])^(y==t[z].s[1])) rotate(x);
else rotate(y);
rotate(x);
}
if (k==0) rt=x;
}

void insert(int v)
{
// cout<<v<<endl;
int u=rt,p=0;
while (u) p=u,u=t[u].s[t[u].v<v];
int now=++idx;
if (p) t[p].s[v>t[p].v]=now;
t[now].init(v,p);
splay(now,0);
}

int get_k(int k)
{
int u=rt;
while (true)
{
pushdown(u);
if (t[t[u].s[0]].size>=k) u=t[u].s[0];
else if (t[t[u].s[0]].size==k-1) return u;
else k-=t[t[u].s[0]].size+1,u=t[u].s[1];
}
return -1;
}

void output(int u)
{
pushdown(u);
if (t[u].s[0]) output(t[u].s[0]);
if (1<=t[u].v&&t[u].v<=n) printf("%d ",t[u].v);
if (t[u].s[1]) output(t[u].s[1]);
}

int main()
{
int m;
cin>>n>>m;
for (int i=0;i<=n+1;++i) insert(i);
while (m--)
{
int l,r;
scanf("%d%d",&l,&r);
l=get_k(l),r=get_k(r+2);
splay(l,0);splay(r,l);
t[t[r].s[0]].flag^=1;
}
output(rt);
return 0;
}

T2:[NOI2004] 郁闷的出纳员

题目传送门 Luogu

题目传送门 AcWing

还是接近于模板,但是要处理一下整个修改。

我们可以记录一个 $delta$,记录整个的偏移量。

差不多就可以了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

const int N=1e5+10,INF=1e9;

struct Node{
int s[2],p;
int val,size;
inline void init(int _val,int _p)
{
val=_val;p=_p;
size=1;
}
}tr[N];

int n,rt,m,tot,delta;

void pushup(int p)
{
tr[p].size=tr[tr[p].s[0]].size+tr[tr[p].s[1]].size+1;
}

void rotate(int x)
{
int y=tr[x].p,z=tr[y].p;
int k=tr[y].s[1]==x;
tr[x].p=z,tr[z].s[y==tr[z].s[1]]=x;
tr[tr[x].s[k^1]].p=y,tr[y].s[k]=tr[x].s[k^1];
tr[y].p=x,tr[x].s[k^1]=y;
pushup(y);pushup(x);
}

void splay(int x,int k)
{
while (tr[x].p!=k)
{
int y=tr[x].p,z=tr[y].p;
if (z!=k)
if ((y==tr[z].s[1])^(tr[y].s[1]==x)) rotate(x);
else rotate(y);
rotate(x);
}
if (!k) rt=x;
}

int get(int val)
{
int u=rt,res;
while (u)
{
if (tr[u].val>=val) res=u,u=tr[u].s[0];
else u=tr[u].s[1];
}
return res;
}

int get_k(int k)
{
int u=rt,tmp=k;
while (true)
{
if (tr[tr[u].s[0]].size>=k) u=tr[u].s[0];
else if (tr[tr[u].s[0]].size+1==k)
{
return tr[u].val;
}
else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
}
return -1;
}

int insert(int val)
{
int u=rt,p=0;
while (u) p=u,u=tr[u].s[val>tr[u].val];
u=++tot;
if (p) tr[p].s[val>tr[p].val]=u;
tr[u].init(val,p);
splay(u,0);
return u;
}

int main()
{
scanf("%d %d",&n,&m);
int L=insert(-INF),R=insert(INF),cnt=0,x;

char op[5];
while (n--)
{
scanf("%s %d",op,&x);
if (op[0]=='I')
{
if (x>=m) x-=delta,insert(x),cnt++;
}
else if (op[0]=='A') delta+=x;
else if (op[0]=='S')
{
delta-=x;
R=get(m-delta);
splay(R,0);splay(L,R);
tr[L].s[1]=0;
pushup(L);pushup(R);
}
else
{
if (tr[rt].size-2<x) puts("-1");
else printf("%d\n",get_k(tr[rt].size-x)+delta);
}
}
printf("%d\n",cnt-(tr[rt].size-2));
return 0;
}

T3:[HNOI2012] 永无乡

题目传送门 Luogu

题目传送门 AcWing

主要是如何维护哪些点在同一个集合。

首先要使用并查集,同时使用启发式合并。

顾名思义,启发式合并就是将小的合并到大的上面。

没什么说的了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;

const int N = 1e6 + 10;
struct Node
{
int s[2], p;
int key, id;
int size;
void init(int _key, int _id, int _p)
{
key=_key;
id=_id;
p=_p;
size=1;
}
} tr[N];
int n, tot, rt[N], fa[N], m;

void pushup(int p)
{
tr[p].size = tr[tr[p].s[0]].size + tr[tr[p].s[1]].size + 1;
}

void rotate(int x)
{
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
tr[x].p=z,tr[z].s[y==tr[z].s[1]]=x;
tr[tr[x].s[k^1]].p=y,tr[y].s[k]=tr[x].s[k^1];
tr[y].p = x, tr[x].s[k^1] = y;
pushup(y);pushup(x);
}

void splay(int x, int k, int b)
{
while (tr[x].p != k)
{
int y = tr[x].p, z = tr[y].p;
if (z != k)
if ((tr[z].s[1] == y) ^ (tr[y].s[1] == x)) rotate(x);
else rotate(y);
rotate(x);
}
if (!k) rt[b] = x;
}

int get_k(int k, int b)
{
int u = rt[b];
while (u)
{
if (tr[tr[u].s[0]].size >= k)
u = tr[u].s[0];
else if (tr[tr[u].s[0]].size == k - 1)
return tr[u].id;
else
k -= tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
}
return -1;
}

void insert(int key, int id, int b)
{
int u = rt[b], p=0;
while (u)
p = u, u = tr[u].s[key > tr[u].key];
u = ++tot;
if (p)
tr[p].s[key > tr[p].key] = u;
tr[u].init(key, id, p);
splay(u, 0, b);
}

void dfs(int u, int b)
{
if (tr[u].s[0]) dfs(tr[u].s[0], b);
if (tr[u].s[1]) dfs(tr[u].s[1], b);
insert(tr[u].key, tr[u].id, b);
}

int find(int x)
{
if (fa[x] != x)
fa[x] = find(fa[x]);
return fa[x];
}

int main()
{
scanf("%d %d",&n, &m);
int x, y;
for (int i = 1; i <= n; ++ i)
{
scanf("%d", &x);
rt[i]=i;
fa[i]=i;
tr[i].init(x,i,0);
}
tot=n;

while (m--)
{
scanf("%d %d", &x, &y);
x = find(x);
y = find(y);
if (x==y) continue;
if (tr[rt[x]].size > tr[rt[y]].size)
swap(x, y);
dfs(rt[x], y);
fa[x] = y;
}
scanf("%d", &m);
char op[5];
while (m--)
{
scanf("%s %d %d", op, &x, &y);
if (op[0] == 'Q')
{
x = find(x);
if (tr[rt[x]].size < y)
puts("-1");
else
printf("%d\n", get_k(y, x));
}
else
{
x = find(x);
y = find(y);
if (x == y)
continue;
if (tr[rt[x]].size > tr[rt[y]].size)
swap(x, y);
dfs(rt[x], y);
fa[x] = y;
}
}
return 0;
}

T4:[NOI2005] 维护数列

题目传送门 Luogu

题目传送门 AcWing

很麻烦,我们可以一个一个的看。

(先咕着)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

const int N=1e6+10,INF=1e9+7;

struct Node{
int s[2];
int val,fa,size;
bool flag,sm;
int ls,rs,tots,sum;
inline void init(int _val,int _fa)
{
sum=val=_val;fa=_fa;s[0]=s[1]=0;
flag=sm=false;
if (val>=0) ls=rs=tots=val;
else ls=rs=0,tots=val;
size=1;
}
}tr[N];

int q[N],st,rt;
int tmp[N];

inline void pushup(int x)
{
Node &u=tr[x],&l=tr[tr[x].s[0]],&r=tr[tr[x].s[1]];
u.size=l.size+r.size+1;
u.sum=l.sum+r.sum+u.val;
u.ls=max(l.ls,l.sum+u.val+r.ls);
u.rs=max(r.rs,r.sum+u.val+l.rs);
u.tots=max(max(l.tots,r.tots),l.rs+u.val+r.ls);
}

inline void pushdown(int x)
{
Node &u=tr[x],&l=tr[tr[x].s[0]],&r=tr[tr[x].s[1]];
if (u.sm)
{
u.flag=0;u.sm=0;
if (u.s[0]) l.sm=1,l.val=u.val,l.sum=u.val*l.size;
if (u.s[1]) r.sm=1,r.val=u.val,r.sum=u.val*r.size;
if (u.s[0])
if (l.val>=0) l.ls=l.rs=l.tots=l.sum;
else l.ls=l.rs=0,l.tots=l.val;
if (u.s[1])
if (r.val>=0) r.ls=r.rs=r.tots=r.sum;
else r.ls=r.rs=0,r.tots=r.val;
// if (l==N-1||r==N-1) puts("Failed");
}
if (u.flag)
{
u.flag=0;
l.flag^=1;r.flag^=1;
swap(l.ls,l.rs);swap(l.s[0],l.s[1]);
swap(r.ls,r.rs);swap(r.s[0],r.s[1]);
// if (l==N-1||r==N-1) puts("Failed");
}
// pushup(x);
}

void rotate(int x)
{
// pushdown(x);
int y=tr[x].fa,z=tr[y].fa;
int k=x==tr[y].s[1];
tr[z].s[y==tr[z].s[1]]=x,tr[x].fa=z;
tr[tr[x].s[k^1]].fa=y,tr[y].s[k]=tr[x].s[k^1];
tr[y].fa=x,tr[x].s[k^1]=y;
pushup(y);pushup(x);
}

void splay(int x,int k)
{
while (tr[x].fa!=k)
{
// pushdown(x);
int y=tr[x].fa,z=tr[y].fa;
// pushdown(z);pushdown(y);
if (z!=k)
if ((x==tr[y].s[1])^(y==tr[z].s[1])) rotate(x);
else rotate(y);
rotate(x);
}
if (!k) rt=x;
}

int find(int k)
{
// printf("find:%d(%d,%d,%d) %d\n",x,tr[x].val,tr[x].size,tr[tr[x].s[0]].size,pos);
int u=rt;
while (u)
{
pushdown(u);
if (tr[tr[u].s[0]].size>=k) u=tr[u].s[0];
else if (tr[tr[u].s[0]].size+1==k) return u;
else k-=tr[tr[u].s[0]].size+1, u=tr[u].s[1];
}
}

void dfs(int x)
{
q[st++]=x;
if (tr[x].s[0]) dfs(tr[x].s[0]);
if (tr[x].s[1]) dfs(tr[x].s[1]);
}

int make_tree(int fa,int l,int r)
{
// printf("make_tree %d %d %d\n",fa,l,r);
int mid=l+r>>1,u=q[--st];
tr[u].init(tmp[mid],fa);
if (l==r)
{
if (tmp[mid]>=0) tr[u].tots=tr[u].ls=tr[u].rs=tmp[mid];
else tr[u].tots=tmp[mid],tr[u].ls=tr[u].rs=0;
return u;
}
if (l<mid) tr[u].s[0]=make_tree(u,l,mid-1);
if (r>mid) tr[u].s[1]=make_tree(u,mid+1,r);
pushup(u);
return u;
}

void Insert(int pos,int tot)
{
int l=find(pos+1),r=find(pos+2);
splay(l,0);splay(r,l);
// pushdown(r);pushdown(l);
tr[r].s[0]=make_tree(r,0,tot-1);
pushup(r);pushup(l);
}

void remove(int pos,int tot)
{
int l=find(pos),r=find(pos+tot+1);
splay(l,0);splay(r,l);
// pushdown(r);
// printf("%d:%d %d:%d\n",l,tr[l].val,r,tr[r].val);
dfs(tr[r].s[0]);
tr[r].s[0]=0;pushup(r);pushup(l);
}

void make_same(int pos,int tot,int x)
{
int l=find(pos),r=find(pos+tot+1);
// printf("%d:%d %d:%d\n",l,tr[l].val,r,tr[r].val);
splay(l,0);splay(r,l);
Node &now=tr[tr[r].s[0]];
now.val=x;now.sm=1;
now.sum=x*now.size;
now.flag=0;
if (x>0)
{
// if (now==N-1) printf("Failed:::::%d %d %d %d\n",tr[now].ls,tr[now].rs,tr[now].sum,tr[now].tots);
now.ls=now.rs=now.tots=now.sum;
}
else now.ls=now.rs=0,now.tots=x;
pushup(r);pushup(l);
}

void reverse(int pos,int tot)
{
int l=find(pos),r=find(pos+tot+1);
// printf("%d:%d %d:%d\n",l,tr[l].val,r,tr[r].val);
splay(l,0);splay(r,l);
Node &now=tr[tr[r].s[0]];
now.flag^=1;
swap(now.s[0],now.s[1]);
swap(now.ls,now.rs);
pushup(r);pushup(l);
}

int get_sum(int pos,int tot)
{
int l=find(pos),r=find(pos+tot+1);
// printf("%d:%d %d:%d\n",l,tr[l].val,r,tr[r].val);
splay(l,0);splay(r,l);
return tr[tr[r].s[0]].sum;
}

void output(int x)
{
pushdown(x);
if (tr[x].s[0]) output(tr[x].s[0]);
printf("%d\t%d\t%d\t%d\t%d\t%d\t%d\t%d\n",x,tr[x].val,tr[x].sum,tr[x].ls,tr[x].rs,tr[x].tots,tr[x].s[0],tr[x].s[1]);
if (tr[x].s[1]) output(tr[x].s[1]);
}

int main()
{
// freopen("P2042_2.in","r",stdin);
// freopen("myans.out","w",stdout);
int n,m;
for (int i=1;i<N;++i) q[st++]=i;
scanf("%d %d",&n,&m);
tr[0].tots=-INF;
tmp[0]=tmp[n+1]=-INF;
for (int i=1;i<=n;++i) scanf("%d",&tmp[i]);
rt=make_tree(0,0,n+1);
int pos,tot,x;
char op[20];
while (m--)
{
scanf("%s",op);
if (strcmp(op,"INSERT")==0)
{
scanf("%d %d",&pos,&tot);
for (int i=0;i<tot;++i) scanf("%d",&tmp[i]);
Insert(pos,tot);
}
else if (strcmp(op,"DELETE")==0)
{
scanf("%d %d",&pos,&tot);
remove(pos,tot);
}
else if (strcmp(op,"MAKE-SAME")==0)
{
scanf("%d %d %d",&pos,&tot,&x);
make_same(pos,tot,x);
}
else if (strcmp(op,"REVERSE")==0)
{
scanf("%d %d",&pos,&tot);
reverse(pos,tot);
}
else if (strcmp(op,"GET-SUM")==0)
{
scanf("%d %d",&pos,&tot);
printf("%d\n",get_sum(pos,tot));
}
else if (strcmp(op,"MAX-SUM")==0)
{
// output(rt);
printf("%d\n",tr[rt].tots);
}
}
return 0;
}