本文共 2223 字,大约阅读时间需要 7 分钟。
题目大意: 问两个字符串有多少对相同的子串(位置不同算两对)。
以一个字符串造SAM,然后把另一个字符串丢到上面去匹配。
具体是枚举一个 i i i,然后看字符串 B B B 的前 i i i 个字符组成的子串,找到这个子串的最长的 是 A A A 字符串的子串的 后缀。
假设这段后缀是 x x x ~ i i i,那么可以得出, y y y ~ i ( y ∈ [ x , i ] ) i~(y\in[x,i]) i (y∈[x,i]) 这些后缀都是 A A A 的一个子串,而每个后缀的贡献就是他在 A A A 中的出现次数,这个可以利用 A A A 的SAM来统计。
说白了,就是要找到这个子串在SAM上的对应状态,然后这个状态的 e n d p o s endpos endpos 集大小就是了(代码中 e n d p o s endpos endpos 集大小记录在 s i z e size size 数组里)。
但是显然不可能枚举每一个后缀去匹配,我们维护一个 n o w now now 和 l e n len len, n o w now now 表示现在匹配到SAM的哪个状态, l e n len len 记录前 i − 1 i-1 i−1 位最多成功匹配多少位,然后如果 n o w now now 没有通过字符 B i B_i Bi 能到达的后继状态,就往后继链接走,直到有能通过 B i B_i Bi 到达的后置状态为止,然后更新一下 l e n len len,这个状态的贡献就是 ( l e n − l e n ( l i n k ( n o w ) ) ) × s i z e [ n o w ] (len-len(link(now)))\times size[now] (len−len(link(now)))×size[now],然后还要加上后继链接能到达的那些状态的贡献,可以预处理然后 O ( 1 ) O(1) O(1) 读取。
具体实现就看代码吧:
#include#include #include using namespace std;#define maxn 400010#define ll long longint n;char s[maxn];struct state{ int len,link,next[26];}st[maxn];int id=0,last=0,now,p,q;int size[maxn];ll sum[maxn];void extend(int x){ now=++id; st[now].len=st[last].len+1;size[now]=1; for(p=last;p!=-1&&!st[p].next[x];p=st[p].link)st[p].next[x]=now; if(p!=-1) { q=st[p].next[x]; if(st[p].len+1==st[q].len)st[now].link=q; else { int clone=++id; st[clone]=st[q];st[clone].len=st[p].len+1; for(;p!=-1&&st[p].next[x]==q;p=st[p].link)st[p].next[x]=clone; st[q].link=st[now].link=clone; } } last=now;}int c[maxn],A[maxn];void work(){ for(int i=1;i<=id;i++)c[st[i].len]++; for(int i=1;i<=id;i++)c[i]+=c[i-1]; for(int i=1;i<=id;i++)A[c[st[i].len]--]=i; for(int i=id;i>=1;i--)size[st[A[i]].link]+=size[A[i]]; for(int i=1;i<=id;i++)sum[A[i]]+=sum[st[A[i]].link]+size[A[i]]*(st[A[i]].len-st[st[A[i]].link].len);}ll ans=0;void solve(){ int now=0,len=0; for(int i=1;i<=n;i++) { while(now&&!st[now].next[s[i]-'a'])now=st[now].link; len=min(len,st[now].len)+1;now=st[now].next[s[i]-'a']; if(now!=0)ans+=(len-st[st[now].link].len)*size[now]+sum[st[now].link]; }}int main(){ scanf("%s",s+1);n=strlen(s+1);st[0].link=-1; for(int i=1;i<=n;i++)extend(s[i]-'a'); work(); scanf("%s",s+1);n=strlen(s+1); solve();printf("%lld",ans);}
转载地址:http://ljnib.baihongyu.com/