树套树

本质是一个区间通过线段树划分为 $\log n$ 个区间,每一个区间在分别维护,一般用 $O(\log n)$ 的数据结构,所以时间复杂度为 $O(n\log ^ 2 n)$,空间复杂度为 $O(n\log n)$,使用时注意空间。

树套树

1. 定义

顾名思义,就是一棵树套着另一棵树。

例如,对于每一个线段树的节点所维护的区间,都用 Splay 维护成有序序列。

一般外层是线段树或者树状数组,内层是一个平衡树或者线段树。

一般内层使用 STL。

请注意,如果你没有学过以上知识,请看我的其他博客。

我们通过例题来理解。

2. 例题

T1:树套树-简单版

题目传送门 AcWing

假设没有区间的限制,那么我们就可以直接使用 lower_bound 等函数即可(使用 set)。

如何加上区间的限制呢?

那么,我们用一个线段树维护区间,并对每一个节点都建立一个 set,存放整个区间的有序序列。

回顾线段树,相当于是讲一个区间维护成 $\log n$ 个区间。

对于每一个区间,都返回前驱即可。

单次复杂度为 $O(\log^2 n)$。

如果单次修改,我们也像线段树一样,删除在该区间的数,插入新的树即可。

时间复杂度为 $O(m\log^2n)$,空间复杂度为 $O(n\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include<bits/stdc++.h>
#define l(p) (p<<1)
#define r(p) (p<<1|1)
using namespace std;

const int N=5e4+10,INF=1e9;
struct Node{
int l,r;
multiset <int> s;
}tr[N<<2];
int a[N],n;

void build(int p,int l,int r)
{
tr[p]=(Node){l,r};
tr[p].s.clear();
tr[p].s.insert(-INF),tr[p].s.insert(INF);
for (int i=l;i<=r;++i) tr[p].s.insert(a[i]);
if (l==r) return ;
int mid=l+r>>1;
build(l(p),l,mid);
build(r(p),mid+1,r);
}

void modify(int p,int x,int val)
{
tr[p].s.erase(tr[p].s.find(a[x]));
tr[p].s.insert(val);+
if (tr[p].l==tr[p].r) return;
int mid=tr[p].l+tr[p].r>>1;
if (x<=mid) modify(l(p),x,val);
else modify(r(p),x,val);
}

int query(int p,int l,int r,int x)
{
if (tr[p].l>=l&&tr[p].r<=r)
{
auto it=tr[p].s.lower_bound(x);
--it;return *it;
}
int mid=tr[p].l+tr[p].r>>1,ans=-INF;
if (l<=mid) ans=max(ans,query(l(p),l,r,x));
if (r>mid) ans=max(ans,query(r(p),l,r,x));
return ans;
}

int main()
{
int op,l,r,x,m;
cin>>n>>m;
for (int i=1;i<=n;++i) scanf("%d",a+i);
build(1,1,n);
while (m--)
{
scanf("%d",&op);
if (op==1)
{
scanf("%d %d",&l,&x);
modify(1,l,x);
a[l]=x;
}
else{
scanf("%d %d %d",&l,&r,&x);
printf("%d\n",query(1,l,r,x));
}
}
return 0;
}

T2:树套树

题目传送门 Luogu

题目传送门 AcWing

像上道题一样,我们使用树套树,线段树加 BST。

由于 set 不能维护当前子树的大小,我们就不能使用 STL。

手写 Splay/Treap 等平衡树,外面套一个线段树。

代码很长,至少 200 行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define l(p) (p<<1)
#define r(p) (p<<1|1)
using namespace std;

const int N=1e5+10,INF=0x7fffffff;

struct Node{
int l,r,key,val;
int size,cnt;
}tr[20*N];

struct Seg{
int l,r;
int rt;
}seg[4*N];

int tot=0,pos,a[N],n,m;

int get_node(int key)
{
tr[++tot]=(Node){0,0,key,rand(),1,1};
return tot;
}

void pushup(int p)
{
tr[p].size=tr[tr[p].l].size+tr[tr[p].r].size+tr[p].cnt;
}

void zig(int &p)
{
int q=tr[p].l;
tr[p].l=tr[q].r,tr[q].r=p,p=q;
pushup(tr[p].r);pushup(p);
}

void zag(int &p)
{
int q=tr[p].r;
tr[p].r=tr[q].l;tr[q].l=p,p=q;
pushup(tr[p].l);pushup(p);
}

void build(int pos)
{
get_node(-INF),get_node(INF);
tr[tot-1].r=tot,seg[pos].rt=tot-1;
pushup(tot-1);
}

void insert(int &p,int key)
{
if (!p)
{
p=get_node(key);
return;
}
if (tr[p].key==key)
{
tr[p].cnt++;
}
else if (tr[p].key>key)
{
insert(tr[p].l,key);
if (tr[p].val<tr[tr[p].l].val) zig(p);
}
else{
insert(tr[p].r,key);
if (tr[p].val<tr[tr[p].r].val) zag(p);
}
pushup(p);
}

void remove(int &p,int key)
{
if (tr[p].key==key)
{
if (tr[p].cnt>1) tr[p].cnt--;
else
{
if (tr[p].l||tr[p].r)
{
if (!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val) zig(p),remove(tr[p].r,key);
else zag(p),remove(tr[p].l,key);
pushup(p);
}
else p=0;
}
pushup(p);
return;
}
if (tr[p].key>key) remove(tr[p].l,key);
else remove(tr[p].r,key);
pushup(p);
}

int get_rank_by_key(int p,int key)
{
if (!p) return 0;
if (tr[p].key==key) return tr[tr[p].l].size;
if (tr[p].key>key) return get_rank_by_key(tr[p].l,key);
return tr[tr[p].l].size+tr[p].cnt+get_rank_by_key(tr[p].r,key);
}

int get_key_by_rank(int p,int rank)
{
if (!p) return INF;
if (tr[tr[p].l].size>=rank) return get_key_by_rank(tr[p].l,rank);
if (tr[tr[p].l].size+tr[p].cnt>=rank) return tr[p].key;
return get_key_by_rank(tr[p].r,rank-tr[p].cnt-tr[tr[p].l].size);
}

int get_prev(int p,int key)
{
if (!p) return -INF;
if (tr[p].key>=key) return get_prev(tr[p].l,key);
return max(tr[p].key,get_prev(tr[p].r,key));
}

int get_next(int p,int key)
{
if (!p) return INF;
if (tr[p].key<=key) return get_next(tr[p].r,key);
return min(tr[p].key,get_next(tr[p].l,key));
}
// 以上为 Treap 的模板
void make_tree(int p,int l,int r)
{
build(p);
seg[p].l=l,seg[p].r=r;
for (int i=l;i<=r;++i)
{
insert(seg[p].rt,a[i]);
}
if (l==r) return;
int mid=l+r>>1;
make_tree(l(p),l,mid);
make_tree(r(p),mid+1,r);
}

void seg_change(int p,int pos,int k)
{
remove(seg[p].rt,a[pos]);
insert(seg[p].rt,k);
if (seg[p].l==seg[p].r) return;
int mid=seg[p].l+seg[p].r>>1;
if (pos<=mid) seg_change(l(p),pos,k);
else seg_change(r(p),pos,k);
}

int seg_get_prev(int p,int l,int r,int x)
{
if (seg[p].l>=l&&seg[p].r<=r) return get_prev(seg[p].rt,x);
int mid=seg[p].l+seg[p].r>>1,ans=-INF;
if (l<=mid) ans=max(ans,seg_get_prev(l(p),l,r,x));
if (r>mid) ans=max(ans,seg_get_prev(r(p),l,r,x));
return ans;
}

int seg_get_next(int p,int l,int r,int x)
{
if (seg[p].l>=l&&seg[p].r<=r) return get_next(seg[p].rt,x);
int mid=seg[p].l+seg[p].r>>1,ans=INF;
if (l<=mid) ans=min(ans,seg_get_next(l(p),l,r,x));
if (r>mid) ans=min(ans,seg_get_next(r(p),l,r,x));
return ans;
}

int seg_get_rank_by_key(int p,int l,int r,int key)
{
if (seg[p].l>=l&&seg[p].r<=r) return get_rank_by_key(seg[p].rt,key)-1;
int mid=seg[p].l+seg[p].r>>1,ans=0;
if (l<=mid) ans+=seg_get_rank_by_key(l(p),l,r,key);
if (r>mid) ans+=seg_get_rank_by_key(r(p),l,r,key);
// printf("%d %d %d key=%d ans=%d\n",p,seg[p].l,seg[p].r,key,ans);
return ans;
}
//以上为线段树的模板
int main()
{
scanf("%d %d",&n,&m);
for (int i=1;i<=n;++i) scanf("%d",a+i);
make_tree(1,1,n);
/*remove(seg[2].rt,2);
printf("%d\n",get_rank_by_key(seg[2].rt,4)-1);*/
// printf("%d\n",seg_get_rank_by_key(1,1,4,INF/2));
int op,l,r,pos,x;
while (m--)
{
scanf("%d",&op);
if (op==3)
{
scanf("%d %d",&pos,&x);
seg_change(1,pos,x);
a[pos]=x;
}
else
{
scanf("%d %d %d",&l,&r,&x);
if (op==1) printf("%d\n",seg_get_rank_by_key(1,l,r,x)+1);
if (op==2)
{
int lval=-INF,rval=INF,tot=0;
while (lval<rval)
{
int mid=(long long)lval+rval+1>>1;
// cout<<lval<<' '<<mid<<' '<<rval<<'\t'<<seg_get_rank_by_key(1,l,r,mid)<<endl;
if (seg_get_rank_by_key(1,l,r,mid)>=x) rval=mid-1;
else lval=mid;
}
printf("%d\n",lval);
}
if (op==4) printf("%d\n",seg_get_prev(1,l,r,x));
if (op==5) printf("%d\n",seg_get_next(1,l,r,x));
}
}
return 0;
}

T3:K大数查询

题目传送门 Luogu

题目传送门 AcWing

首先,由于范围过大,我们需要离散化。

然后,我们要使用树套树,可以完成。但是有其他办法,所以没有写代码。

到时再贴树套树的做法吧。