前言

感觉这个题跟 P7114 [NOIP2020] 字符串匹配 莫名有点联系啊,我如果没写过那题是肯定不会做这个题的。

思路

一开始先特判掉全 a 串,显然此时答案为长度 1-1,以下讨论全部默认给定的串不是全 a 串。

性质 1:一个全 a 串一定不能成为答案,证明显然。

性质 2:一个合法的答案串一定是原串的一段前缀删去前缀的若干个 a

证明:设原串 ss 中第一个非 a 的位置为 pp,一个不满足该性质的 ss 的子串 t=slsl+1srt=s_ls_{l+1}\cdots s_r(不妨设 lltt 第一次出现的左端点),则条件等价于 p<lp<l。如果 [1,l)[1,l) 能被合法划分,由于 sps_p≠ a,这段字符串中必然有一段 [l,r]=t (1lpr<l)[l',r']=t\ (1\leq l'\leq p\leq r'<l),否则第 pp 个字符将无法被划分进任何一个子串。

性质 3:如果 [l,r](l<p)[l,r](l<p) 是合法的答案串,则 [l+1,r][l+1,r] 同样是。

证明:因为 l<pl<psls_l 必定为 a。于是将每个等于 [l,r][l,r] 的子串开头单分出来一个 a 就能形成一种合法的划分。

有了这些性质,我们很容易想到做法:枚举每个前缀 [1,i][1,i],先验证 [p,i][p,i] 这个开头没有 a 的串是否合法,若合法则求出最多能往串左边塞几个 a

实现上,先预处理出整串的哈希和每个下标 ii 右边第一个非 a 下标 nxtinxt_i,验证一个长度为 lenlen 的串时,直接用哈希查询然后每次往右跳 lenlen 个字符,若当前字符 sis_ia 则直接跳到 nxtinxt_i 即可。同时记录上一次查询的右端点,把每次跳到的左端点和上一个右端点之间 a 的个数取 min\min,即为最多能往左塞的 a 的个数。由于每次保底跳 lenlen 个字符,时间复杂度为调和级数约 O(nlnn)\mathcal{O}(n\ln n)

记得多搞几个哈希模数,实测使用 109+7,99824435310^9+7,998244353 两个模数都会被卡。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
constexpr int N = 2e5+5;
constexpr uint mod[3] = {1e9+7,998244353,19260817};

string s;
int n,nxt[N];
uint hs[N][3],b[N][3];

inline uint gethash(int l,int r,int m){return ((ll)hs[r][m]-(1ll*hs[l-1][m]*b[r-l+1][m]%mod[m])+mod[m])%mod[m];}

void init(){
n = s.length();
for(int j=0;j<=2;j++){
for(int i=1;i<=n;i++) hs[i][j] = (1ull*hs[i-1][j]*26+s[i-1])%mod[j];
}
}

int check(int l,int r){
int mn = inf;
for(int m=0;m<=2;m++){
uint h = gethash(l,r,m);
int len = (r-l+1);
int lst = r;
for(int i=r+1;i<=n;i+=len){
if(s[i-1]=='a'){i = nxt[i]-len;continue;}
mn = min(mn,i-lst-1);
lst = i+len-1;
uint nh = gethash(i,min(i+len-1,n),m);
if(nh!=h) return -1;
}
}
return mn;
}

void solve(){
cin>>s;
int cnta = 0;
for(char c:s) cnta+=(c=='a');
if(cnta==s.length()){cout<<s.length()-1<<"\n";return;}
init();
stack<int> st;
for(int i=1;i<=n;i++){
nxt[i] = n+1;
if(s[i-1]=='a') st.push(i);
else{
while(!st.empty()){
nxt[st.top()] = i;
st.pop();
}
}
}
int p = 1;
while(s[p-1]=='a') ++p;
ll ans = 0;
for(int i=p;i<=n;i++) ans+=min(check(p,i)+1,p);
cout<<ans<<"\n";
}

void _init(){
for(int j=0;j<=2;j++){
b[0][j] = 1;
for(int i=1;i<N;i++) b[i][j] = 1ull*b[i-1][j]*26%mod[j];
}
}