点分治和点分树

将来自数组的分治搬到树上。

1. 主要思想

树上分治分为点分治和边分治。

边分治主要因为时间复杂度容易被卡为 $O(n)$,而点分治可以保证时间复杂度为 $O(\log n)$。

所以边分治不太常用,而点分治相对常用。

点分树是点分治的动态问题。

2. 主要方法

以下面一个例题为例:

T1:【模板】点分治

题目传送门 Luogu

题目传送门 AcWing

其实,点分治就是取一个点,然后分治为几棵子树。

关键是如何统计不同子树间的信息。

首先,将所有点到根节点全部存下来。

为了处理,我们可以先将所有的排序,然后使用双指针即可。

要统计不同子树间的,我们可以先将所有的情况减去不符合条件的情况即可。

注意,因为要递归,我们希望层数尽量少。

所以,我们选择树的重心,这样不会超过 $\log n$ 层。

时间复杂度为 $O(n\log^2 n)$。

由于时间都卡得很紧,需要一定优化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#define R register
using namespace std;

typedef long long ll;
const int N=1e4+10,Maxn=1e7+10;

int h[N],e[2*N],ne[2*N],w[2*N],idx=1,s[N],rt,mx=1e7,son[N],sizetot,st[N];
int n,m,k,top;
ll ans=0;
bool vis[N];

inline void add(R int a,R int b,R int c)
{
e[++idx]=b,w[idx]=c,ne[idx]=h[a],h[a]=idx;
}

inline void findroot(R int x,R int fa)
{
// cout<<x<<endl;
s[x]=1;
for (R int i=h[x];i!=-1;i=ne[i])
{
if (e[i]==fa||vis[e[i]]) continue;
findroot(e[i],x);
s[x]+=s[e[i]];
son[x]=max(son[x],s[e[i]]);
}
son[x]=max(son[x],sizetot-s[x]);
if (son[x]<mx) mx=son[x],rt=x;
}

inline void query(R int x,R int fa,R int d)
{
st[++top]=d;
for (R int i=h[x];i!=-1;i=ne[i])
{
if (vis[e[i]]||e[i]==fa) continue;
query(e[i],x,d+w[i]);
}
}

inline void solve(R int x,R int d,R int f)
{
top=0;query(x,0,d);
sort(st+1,st+top+1);
R int j=top;
ll now=0;
for (R int i=1;i<=top;++i)
{
while (j&&st[j]+st[i]>k) j--;
now+=j;
}
for (int i=1;i<=top;++i)
if (st[i]*2<=k) now--;
now/=2;
ans+=now*f;
}

inline void dfs(R int x)
{
if (k==1000) cout<<x<<endl;
vis[x]=1;
solve(x,0,1);
for (R int i=h[x];i!=-1;i=ne[i])
{
if (vis[e[i]]) continue;
solve(e[i],w[i],-1);
mx=1e7,rt=0,sizetot=s[e[i]];
findroot(e[i],x);
dfs(rt);
}
}

int main()
{
while (scanf("%d %d",&n,&k),n||k)
{
ans=0;idx=0;
for (R int i=1;i<=n;++i) son[i]=0,vis[i]=0,h[i]=-1;
for (R int i=1,x,y,c;i<n;++i) scanf("%d%d%d",&x,&y,&c),x++,y++,add(x,y,c),add(y,x,c);

mx=1e7,rt=0,sizetot=n;
findroot(1,0);dfs(rt);

cout<<ans<<endl;

}

}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

const int N=1e4+10,M=2e4+10;
int h[N],e[M],ne[M],w[M],idx;
int p[N],q[N],n,m;
bool vis[N];
int ans[N],que[N];

void add(int a,int b,int c)
{
e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++;
}

int get_size(int x,int fa)
{
if (vis[x]) return 0;
int res=1;
for (int i=h[x];~i;i=ne[i])
if (e[i]!=fa) res+=get_size(e[i],x);
return res;
}

int get_zx(int x,int fa,int tot,int &zx)
{
if (vis[x]) return 0;
int sum=1,mx=0;
for (int i=h[x];~i;i=ne[i])
{
if (vis[e[i]]||e[i]==fa) continue;
int t=get_zx(e[i],x,tot,zx);
sum+=t;mx=max(mx,t);
}
mx=max(mx,tot-sum);
if (mx<=tot/2) zx=x;
return sum;
}

void query(int x,int fa,int d,int &tot)
{
// cout<<x<<endl;
if (vis[x]) return;
q[tot++]=d;
for (int i=h[x];~i;i=ne[i])
if (e[i]!=fa) query(e[i],x,d+w[i],tot);
}

void solve(int *a,int tot,int f)
{
sort(a,a+tot);
for (int l=1;l<=m;++l)
{
int &k=que[l],res=0;
for (int i=tot-1,j=0;~i;--i)
{
while (j<tot&&a[j]+a[i]<k) j++;
if (a[j]+a[i]>k) continue;
while (j<tot&&a[j]+a[i]==k)
{
j++;
if (j==i+1) continue;
res++;
}
}
ans[l]+=res*f/2;
}
return ;
}

void calc(int x)
{
if (vis[x]) return ;
get_zx(x,-1,get_size(x,-1),x);
vis[x]=1;
int cnt=0,res=0;
for (int i=h[x];~i;i=ne[i])
{
int tot=0;
query(e[i],-1,w[i],tot);
solve(q,tot,-1);
for (int j=0;j<tot;++j)
{
for (int l=1;l<=m;++l)
if (que[l]==q[j]) ans[l]++;
p[cnt++]=q[j];
}
}
solve(p,cnt,1);
for (int i=h[x];~i;i=ne[i]) calc(e[i]);
return ;
}

int main()
{
scanf("%d %d",&n,&m);
memset(h,-1,sizeof h);
memset(vis,0,sizeof vis);idx=0;
for (int i=1,x,y,c;i<n;++i)
{
scanf("%d %d %d",&x,&y,&c);
add(x,y,c);add(y,x,c);
}
for (int i=1;i<=m;++i) scanf("%d",&que[i]);
calc(1);
for (int i=1;i<=m;++i) puts(ans[i]?"AYE":"NAY");
return 0;
}

3. 例题

T2:权值

题目传送门 AcWing

其实和上面的类似,我们可以分为一个一个的子树。

在同一棵子树中,可以递归处理,我们需要考虑的就是不同子树之间的距离。

可以直接处理每一个节点到根节点的距离,存在一个桶中。

其中,桶中存的是每一个对应距离,最小经过的边。

这样就可以了。

注意,我们无需每次清空桶,只需每一次将用过的点的操作撤销即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define mp make_pair
#define fi first
#define se second
using namespace std;

typedef pair<int,int> PII;
const int N=2e5+10,M=4e5+20,Maxn=1e7+10,INF=0x3f3f3f3f;

int h[N],e[M],ne[M],w[M],idx;
int f[Maxn],n,m,ans=INF;
PII p[N],q[N];
bool vis[N];

void add(int a,int b,int c)
{
e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++;
}

int get_size(int x,int fa)
{
if (vis[x]) return 0;
int res=1;
for (int i=h[x];~i;i=ne[i])
{
if (e[i]==fa) continue;
res+=get_size(e[i],x);
}
return res;
}

int get_zx(int x,int fa,int tot,int &zx)
{
if (vis[x]) return 0;
int sum=1,mx=0;
for (int i=h[x];~i;i=ne[i])
{
if (e[i]==fa) continue;
int t=get_zx(e[i],x,tot,zx);
sum+=t;mx=max(mx,t);
}
mx=max(mx,tot-sum);
if (mx<=tot/2) zx=x;
return sum;
}

void query(int x,int fa,int d,int now,int &tot)
{
if (vis[x]||d>m) return ;
q[tot++]=mp(d,now);
for (int i=h[x];~i;i=ne[i])
if (e[i]!=fa) query(e[i],x,d+w[i],now+1,tot);
}

void calc(int x)
{
if (vis[x]) return;
get_zx(x,-1,get_size(x,-1),x);
vis[x]=true;
int cnt=0;
for (int i=h[x];~i;i=ne[i])
{
int tot=0;
query(e[i],x,w[i],1,tot);
for (int l=0;l<tot;++l)
{
PII &tmp=q[l];
if (tmp.fi==m) ans=min(ans,tmp.se);
else ans=min(ans,tmp.se+f[m-tmp.fi]);
p[cnt++]=tmp;
}

for (int l=0;l<tot;++l) f[q[l].fi]=min(f[q[l].fi],q[l].se);
}
for (int i=0;i<cnt;++i) f[p[i].fi]=INF;
// cout<<x<<' '<<ans<<endl;
for (int i=h[x];~i;i=ne[i]) calc(e[i]);
}

int main()
{
memset(h,-1,sizeof h);
memset(f,0x3f,sizeof f);
cin>>n>>m;
for (int i=1,x,y,c;i<n;++i)
{
scanf("%d %d %d",&x,&y,&c);
add(x,y,c);add(y,x,c);
}
calc(0);
if (ans!=INF) cout<<ans<<endl;
else puts("-1");
return 0;
}

4. 点分树

还是来看一下例题:

T3:开店

题目传送门 Luogu

题目传送门 AcWing

点分树所处理的题目没有更改树的形态,而是有很多的在线询问,需要我们回答距离问题。

还是考虑递归。

首先,假设 u 在一个子树,如果当前处理的节点在 u 所在的子树内,那么我们可以递归。

如果在不同节点,我们就可以通过归并的方法来解决。

由于没有更改,我们可以先预处理重心来划分。

可以将重心连接起来,我们发现又是一棵树。

这也是“点分树”的命名来源。

然后,我们应该如何计算所有点到 u 的距离和呢?

有两种情况:

1: 与兄弟子树(即 u 不是重心)。

首先,我们可以将答案分为两部分:u 到重心的距离乘以个数,再加上兄弟子树到重心的总距离。

我们可以将所有节点的年龄以及到重心的距离按年龄排序,直接二分即可,就可以求第一个了。

对于第二个,我们可以预处理前缀和。

2: 与所有子树(即 u 是重心)。

其实和上面一种情况比较相似,我们直接对于每一个子树,按上面的求就是了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;

typedef long long ll;
const int N=150010,M=300010,INF=0x3f3f3f3f;
int h[N],e[M],ne[M],w[M],idx;
int a[N],n,q,atot;
bool vis[N];
ll res;

struct Father{
int zx,now;
ll d;
};
vector <Father> f[N];

struct Son{
int a;ll d;
const bool operator <(const Son &t)const{
return a<t.a;
}
};
vector <Son> s[N][3];

void add(int a,int b,int c)
{
e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++;
}

int get_size(int x,int fa)
{
if (vis[x]) return 0;
int res=1;
for (int i=h[x];~i;i=ne[i])
if (e[i]!=fa) res+=get_size(e[i],x);
return res;
}

int get_zx(int x,int fa,int tot,int &zx)
{
if (vis[x]) return 0;
int sum=1,maxn=0;
for (int i=h[x];~i;i=ne[i])
{
if (e[i]==fa) continue;
int t=get_zx(e[i],x,tot,zx);
maxn=max(maxn,t);sum+=t;
}
maxn=max(maxn,tot-sum);
if (maxn<=tot/2) zx=x;
return sum;
}

void get_son_tree(int x,int fa,ll d,int zx,int k,vector<Son> &p)
{
if (vis[x]) return;
f[x].push_back((Father){zx,k,d});
p.push_back((Son){a[x],d});
for (int i=h[x];~i;i=ne[i])
if (e[i]!=fa) get_son_tree(e[i],x,d+w[i],zx,k,p);
return;
}

void calc(int x)
{
if (vis[x]) return;
get_zx(x,-1,get_size(x,-1),x);
vis[x]=1;
for (int i=h[x],now=0;~i;i=ne[i])
{
if (vis[e[i]]) continue;
vector<Son> &p=s[x][now];
get_son_tree(e[i],x,w[i],x,now,p);
p.push_back((Son){-1,0});
p.push_back((Son){INF,0});
sort(p.begin(),p.end());
for (int i=1;i<p.size();++i) p[i].d+=p[i-1].d;
now++;
}
for (int i=h[x];~i;i=ne[i]) calc(e[i]);
}

void query(int x,int l,int r)
{
res=0;
for (int i=0;i<f[x].size();++i)
{
Father &tmp=f[x][i];
if (a[tmp.zx]>=l&&a[tmp.zx]<=r) res+=tmp.d;
for (int now=0;now<3;++now)
{
vector<Son> &p=s[tmp.zx][now];
if (now==tmp.now||p.empty()) continue;
int tl=lower_bound(p.begin(),p.end(),(Son){l,0})-p.begin(),
tr=lower_bound(p.begin(),p.end(),(Son){r+1,0})-p.begin();
res+=(tr-tl)*tmp.d+p[tr-1].d-p[tl-1].d;
}
}
for (int now=0;now<3;++now)
{
vector<Son> &p=s[x][now];
if (p.empty()) continue;
int tl=lower_bound(p.begin(),p.end(),(Son){l,0})-p.begin(),
tr=lower_bound(p.begin(),p.end(),(Son){r+1,0})-p.begin();
res+=p[tr-1].d-p[tl-1].d;
}
}

int main()
{
memset(h,-1,sizeof h);
cin>>n>>q>>atot;
for (int i=1;i<=n;++i) scanf("%d",a+i);
for (int i=1,x,y,c;i<n;++i)
{
scanf("%d %d %d",&x,&y,&c);
add(x,y,c);add(y,x,c);
}
calc(1);
int x,l,r;
while (q--)
{
scanf("%d %d %d",&x,&l,&r);
l=(l+res)%atot;r=(r+res)%atot;
if (l>r) swap(l,r);
query(x,l,r);
printf("%lld\n",res);
}
return 0;
}