前言

来个绝世唐诗做法,喜提题解区复杂度倒一。

思路

题意即对每个点求内向基环树森林里有多少个点走 kk 步以内能到达自己。

树上的点答案是好求的,对于环上的点,首先用 vector 记录其子树内每个深度的点各有多少个,然后考虑对其子树内最大深度进行根号分治(称最大深度 n\geq \sqrt n 的点为大点,n\leq \sqrt n 的点为小点),来计算环上不同点之间的贡献:

  • 大点个数不超过 n\sqrt n 个,故大点与其他点之间的贡献可以直接暴力枚举计算,复杂度 O(nn)\mathcal{O}(n\sqrt n)

  • 之后只需要计算小点与小点之间的贡献,考虑断环成链,顺序枚举所有点,每枚举到一个点会使之前的点与当前点距离加上 ilsti-lstlstlst 为上次枚举到的点的下标),于是相当于维护一个支持全局下标加 xx、查询下标在 kk 之前的前缀和的数据结构,可以使用优先队列进行懒删除,复杂度 O(nnlogn)\mathcal{O}(n\sqrt n\log n)。然而这样做只能计算环上 ij(i<j)i\rightarrow j(i<j) 的贡献,于是把这个过程再反向做一遍就可以了。

总复杂度 O(nnlogn)\mathcal{O}(n\sqrt n\log n),时限比较宽松,可以通过。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
vector<int> tr[N],p[N];
int n,k,fa[N],col[N],deg[N],T,cnt = 0;
bool ring[N];

void topo(){
for(int i=1;i<=n;i++) deg[fa[i]]++;
queue<int> q;
for(int i=1;i<=n;i++){
ring[i] = 1;
if(!deg[i]) q.push(i);
}
while(!q.empty()){
int u = q.front();q.pop();
ring[u] = 0,deg[fa[u]]--;
if(!deg[fa[u]]) q.push(fa[u]);
}
}

void dfs0(int u,int c){
col[u] = c,p[c].pb_(u);
for(int v:tr[u]){
if(col[v]) continue;
dfs0(v,c);
}
}

int dep[N],mxd[N],ans[N],rt[N];

deque<int> dq;

void dfs1(int u){
int F = 0;
dep[u] = (ring[u]?-1:dep[fa[u]])+1,rt[u] = ring[u]?u:rt[fa[u]],mxd[u] = dep[u];
ans[u]++;
if(dq.size()>k) F = dq.front(),ans[F]--,dq.pop_front();
dq.pb_(u);
for(int v:tr[u]){
if(v==fa[u] || ring[v]) continue;
dfs1(v),mxd[u] = max(mxd[u],mxd[v]);
}
if(F) dq.push_front(F);
dq.pob_();
}

void dfs2(int u){
for(int v:tr[u]){
if(v==fa[u] || ring[v]) continue;
dfs2(v),ans[u]+=ans[v];
}
}

vector<int> c[N],pre[N],nr;
bool vis[N];

void get_ring(int u){
if(vis[u]) return;
vis[u] = 1,nr.pb_(u),get_ring(fa[u]);
}

inline int dis(int i,int j){return i<j?j-i:(int)nr.size()-i+j;}

void calc(int i,int j){
int d = dis(i,j);
if(d>k) return;
ans[nr[j]]+=pre[nr[i]][min(k-d,mxd[nr[i]])];
}

void prt(){
for(int i=1;i<=n;i++) cout<<ans[i]<<"\n";
}

struct P{
int id,d,w;
inline bool operator <(const P &rhs)const{return w<rhs.w;}
inline bool operator >(const P &rhs)const{return w>rhs.w;}
};

void solve(int id){
for(int u:p[id]){
if(!ring[u]) continue;
dfs1(u);
}
nr = vector<int>();
for(int u:p[id]){
if(!ring[u]) continue;
get_ring(u);
break;
}
for(int u:nr) dfs2(u);
if(nr.size()==1) return;
for(int u:nr) c[u].resize(mxd[u]+1);
for(int u:p[id]) c[rt[u]][dep[u]]++;
for(int u:nr){
pre[u].resize(mxd[u]+1);
for(int i=0;i<=mxd[u];i++) pre[u][i] = (i?pre[u][i-1]:0)+c[u][i];
}
for(int i=0;i<nr.size();i++){
if(mxd[nr[i]]<T) continue;
for(int j=0;j<nr.size();j++){
if(i==j || mxd[nr[j]]<T) continue;
calc(i,j);
}
for(int j=0;j<nr.size();j++){
if(i==j || mxd[nr[j]]>=T) continue;
calc(i,j),calc(j,i);
}
}
int now = 0,d = 0,lst = -1;
priority_queue<P> pq;
for(int i=0;i<nr.size();i++){
int u = nr[i];
if(mxd[u]>=T) continue;
if(lst==-1){
lst = i;
for(int j=0;j<=mxd[u];j++) pq.push({u,j,j});
now+=pre[u][mxd[u]];
continue;
}
d+=dis(lst,i);
while(!pq.empty() && pq.top().w+d>k) now-=c[pq.top().id][pq.top().d],pq.pop();
ans[u]+=now;
for(int j=0;j<=mxd[u];j++) pq.push({u,j,j-d});
now+=pre[u][mxd[u]],lst = i;
}
pq = priority_queue<P>();
priority_queue<P,vector<P>,greater<P>> pq2;
now = 0,d = nr.size(),lst = -1;
for(int i=(int)nr.size()-1;i>=0;i--){
int u = nr[i];
if(mxd[u]>=T) continue;
if(lst==-1){
lst = i;
for(int j=0;j<=mxd[u];j++) pq2.push({u,j,j});
continue;
}
d-=dis(i,lst);
while(!pq2.empty() && pq2.top().w+d<=k) now+=c[pq2.top().id][pq2.top().d],pq2.pop();
ans[u]+=now;
for(int j=0;j<=mxd[u];j++) pq2.push({u,j,j+(int)nr.size()-d});
lst = i;
}
}

void solve(){
read(n,k),T = __builtin_sqrt(n);
for(int i=1;i<=n;i++) read(fa[i]),tr[fa[i]].pb_(i);
topo();
for(int i=1;i<=n;i++){
if(!ring[i] || col[i]) continue;
dfs0(i,++cnt);
}
for(int i=1;i<=cnt;i++) solve(i);
for(int i=1;i<=n;i++) cout<<ans[i]<<"\n";
}