D. “a” String Problem
题意
给定一个字符串 s s s,要求把 s s s 拆分成若干段,满足以下要求:
- 拆分出来的每一个子段,要么是子串 t t t,要么是字符 a a a
- 子串 t t t 至少出现一次
- t ≠ " a " t \neq "a " t="a"
问有多少种不同的子串 t t t 满足以上要求
思路
如果 s s s 全是 a a a 的话,假设 ∣ s ∣ = n |s| = n ∣s∣=n,那么答案是: n − 1 n - 1 n−1
否则通过简单观察我们可以发现: t t t 必须包含非 a a a 字符(不一定是所有,只要 t t t 可以覆盖这些非 a a a 字符即可)
假设 s s s 从左到右第一个出现非 a a a 字符的位置是 p 0 p_0 p0,如果我们先固定 t t t 的开头在 p 0 p_0 p0,
我们就可以先枚举 t t t 的长度从 1 → n − p 0 + 1 1 \rarr n - p_0 + 1 1→n−p0+1,那么如何确定当前 t t t 是否满足上述要求?
我们直接在第一个 t t t 的末尾后面,找第一个出现的非 a a a 字符作为第二个 t t t 的开头,然后这后面 l e n len len 个字符必须与 t t t 相等,如果相等则继续往后检查,否则当前 t t t 无效。
我们只需要预处理每一个位置后面第一个非 a a a 字符的位置就可以,倒着扫一遍就可以线性预处理出来。
匹配的过程我们可以使用 Z Z \; Z函数,这是由于本质上是从 p 0 p_0 p0 开始的字符串,拿它去和它自己本身的每个后缀做匹配,自然可以使用 Z Z \; Z 函数。
那么这个检查过程是: O ( n log n ) O(n \log n) O(nlogn) 的(调和级数复杂度)
那么现在问题在于: t t t 不一定是以非 a a a 字符开头的。
其实这个问题很容易处理,假设我们当前有效的从 p 0 p_0 p0 开头的 ∣ t ∣ = l e n |t| = len ∣t∣=len,那么我们在检查过程的同时,记录每个 t t t 的前面到前一个 t t t 的末尾,有多少个 a a a,统计这个最小值 m n mn mn
那么很显然,当前的 t t t 可以往前扩展最多 m n mn mn 个 a a a,最后还是有效的。
那么这里的方案数就是: m n + 1 mn + 1 mn+1
所以答案最后对于每个有效长度累加即可
时间复杂度: O ( n log n ) O(n \log n) O(nlogn)
#include<bits/stdc++.h>
#define fore(i,l,r) for(int i=(int)(l);i<(int)(r);++i)
#define fi first
#define se second
#define endl '\n'
#define ull unsigned long long
#define ALL(v) v.begin(), v.end()const int INF=0x3f3f3f3f;
const long long INFLL=0x3f3f3f3f3f3f3f3fLL;typedef long long ll;std::vector<int> z_function(const std::string& s, int n){std::vector<int> z(n + 1, 0);z[1] = n;int l = 0, r = 0;fore(i, 2, n + 1){if(i <= r) z[i] = std::min(z[i - l + 1], r - i + 1);while(i + z[i] <= n && s[1 + z[i]] == s[i + z[i]])++z[i];if(i + z[i] - 1 > r){l = i;r = i + z[i] - 1;}}return z;
}int main(){std::ios::sync_with_stdio(false);std::cin.tie(nullptr);std::cout.tie(nullptr);int t;std::cin >> t;while(t--){std::string s;std::cin >> s;int n = s.size();s = '0' + s;std::vector<int> nxt_nona(n + 5, n + 5);for(int i = n; i > 0; --i){if(s[i] == 'a') nxt_nona[i] = nxt_nona[i + 1];else nxt_nona[i] = i;}if(nxt_nona[1] > n){ //全是astd::cout << n - 1 << endl;continue;}int p0 = nxt_nona[1];std::string T = s.substr(p0);T = '0' + T;auto z = z_function(T, T.size() - 1);ll ans = 0;fore(len, 1, n - p0 + 2){int mn = p0 - 1;bool ok = true;int lst = p0 + len - 1;for(int j = nxt_nona[p0 + len]; j <= n; j = nxt_nona[j + len]){if(z[j - p0 + 1] < len){ok = false;break;}mn = std::min(mn, j - lst - 1);lst = j + len - 1;}if(ok) ans += mn + 1;}std::cout << ans << endl;}return 0;
}