给你一棵树,每个点有一个字符。再给你一个字符串 s。
然后问你树上的所有简单的路径在 s 上的出现次数的和。
一个比较神奇的题目。
首先考虑
n
2
n^2
n2 暴力,不难想象使用 SAM,试着给
s
s
s 串建一下 SAM 看看。
然后不难想象你就枚举根,然后跑树在 SAM 上跑链,然后出现次数其实就是 DAG 能走到的部分的
s
z
sz
sz 和。
(因为那些位置都是你的后缀)
然后考虑别的算法,因为是树上,我们试着用点分治。
那我们考虑把这条路径拆成了两个部分,那我们可以考虑枚举中间衔接的是
s
s
s 串的哪个位置。
(不能枚举颜色,因为可能是在不同的位置)
也就是说我们需要统计子树里面有多少个到根的链满足出现在
s
s
s 中而且末尾为
p
p
p(设
p
p
p 为你当前枚举的位置),还有跟到这个点的字符串开头为
p
p
p 的。
不难看出我们只需要求出一种,另一种反转一下串即可。
那一个比较容易想到的想法是记录一个数组
n
u
m
i
num_i
numi 表示 SAM 上
i
i
i 这个位置存了多少贡献。
对于每个那些串我们可以在 SAM 上跑(继承地跑),然后给对于的位置
n
u
m
i
num_i
numi 加一。
然后再 DP 一下
n
u
m
i
num_i
numi 加上祖先的值
n
u
m
f
a
i
num_{fa_i}
numfai 即可。
找答案的时候
s
s
s 中第
i
i
i 个位置在 SAM 的位置是
x
x
x,就直接是
n
u
m
x
num_x
numx 了。
但是会发现一个问题,它好像不能继承?
因为你是在前面加字符,不能直接用
s
o
n
i
son_i
soni 数组。
那考虑自己再弄一个数组
S
o
n
i
Son_i
Soni:
考虑看它那个 endpose 集合,因为是连续的,所以如果当前的长度不是最大长度,就一定要跟前面那个对应上,对应上了就还是这个位置,否则就没有了。
那如果是最大长度,就一定会往后跳,那就是要找到一个
f
a
fa
fa 是它的点,满足前面能补上你要的那个,否则也没有答案。
那我们只需要记录达到最大长度的那个跳到哪即可,前面那个我们直接可以那个时候推。
求这个
S
o
n
Son
Son 的时候也不难,就直接每个点给它的父亲的
S
o
n
Son
Son 数组贡献一下。
(所以我们这里需要记录每个点的 endpose 集合的右端点,随便一个即可,这里是
R
i
R_i
Ri)
然后看一下复杂度:
O
(
n
log
n
+
n
m
)
O(n\log n+nm)
O(nlogn+nm)(因为每个点你都要枚举
m
m
m 中间的断点位置)
麻了怎么更垃圾了。
但是观察到带上了
m
m
m,而
n
n
n 大看起来是没问题的。
于是考虑一下能不能两个算法都用上,分析一下各自适用处:
暴力就是
n
2
n^2
n2,
n
n
n 小的时候能用。
这个算法是
n
log
n
+
n
m
n\log n+nm
nlogn+nm,那
n
n
n 大的时候似乎就可以,因为
n
m
nm
nm 里面的
n
n
n 是你作为断点点数。
那我们考虑一下
⩽
n
\leqslant \sqrt{n}
⩽n 的用暴力,别的用第二个算法。
首先看暴力杂度,
T
(
n
)
=
2
T
(
n
2
)
+
O
(
n
2
)
T(n)=2T(\dfrac{n}{2})+O(n^2)
T(n)=2T(2n)+O(n2),主定理之类的能证明是
n
n
n\sqrt{n}
nn。
感性理解一下也不难,你每次只用
1
1
1 的大小是
O
(
n
)
O(n)
O(n),每次用
n
\sqrt{n}
n 的大小是
n
n
n
2
=
n
n
\dfrac{n}{\sqrt{n}}\sqrt{n}^2=n\sqrt{n}
nnn2=nn。
也不会用更大的因为会用另一个算法继续点分治。
接着看另一个算法,考虑根号,你深度就
log
n
\log n
logn 层,而且到
n
\sqrt{n}
n 大小的层就不走了,所以
n
m
nm
nm 里面的
n
n
n 就是
n
\sqrt{n}
n 级别的(大概)。
所以就是对的。
具体操作建议看看打代码。
小细节:
两个 SAM 的
s
s
s 数组是不一样的。(第二个是翻转的,后面也要用的
s
s
s 数组也是不一样的)
区分好
s
o
n
i
son_i
soni 和
S
o
n
i
Son_i
Soni 数组,暴力里面只会用到
s
o
n
i
son_i
soni!
SAM 经典必写错(好吧这个也许只有我
#include
#include
#include
#include
#include
#define ll long long
using namespace std;
const ll N = 5e4 + 100;
struct node {
ll to, nxt;
}e[N << 1];
ll n, m, le[N], KK, a[N], s[N], B;
ll root, max_root, sz[N], pla1[N], pla2[N];
char s1[N], s2[N];
bool in[N];
ll ans;
void add(ll x, ll y) {
e[++KK] = (node){y, le[x]}; le[x] = KK;
}
struct SAM {
struct nde {
ll sz, len, R, fa;
ll son[26], Son[26];
}d[N << 1];
ll tot, lst; ll num[N << 1];
ll tong[N], xl[N << 1], s[N];
void Init() {
tot = lst = 1;
}
ll insert(ll x) {
ll p = lst, np = ++tot; lst = np;
d[np].len = d[p].len + 1; d[np].sz = 1;
d[np].R = d[np].len;
for (; p && !d[p].son[x]; p = d[p].fa) d[p].son[x] = np;
if (!p) d[np].fa = 1;
else {
ll q = d[p].son[x];
if (d[q].len == d[p].len + 1) d[np].fa = q;
else {
ll nq = ++tot; d[nq] = d[q];
d[nq].len = d[p].len + 1; d[nq].sz = 0;
d[np].fa = d[q].fa = nq; d[nq].R = 0;
for (; p && d[p].son[x] == q; p = d[p].fa) d[p].son[x] = nq;
}
}
return np;
}
void build() {
for (ll i = 1; i <= m; i++) tong[i] = 0;
for (ll i = 1; i <= tot; i++) tong[d[i].len]++;
for (ll i = 1; i <= m; i++) tong[i] += tong[i - 1];
for (ll i = 1; i <= tot; i++) xl[tong[d[i].len]--] = i;
for (ll i = tot; i > 1; i--) {
ll now = xl[i];
d[d[now].fa].sz += d[now].sz;
d[d[now].fa].R = d[now].R;
d[d[now].fa].Son[s[d[now].R - d[d[now].fa].len]] = now;
}
}
void clear() {
for (ll i = 1; i <= tot; i++) num[i] = 0;
}
void clac(ll now, ll father, ll x, ll len) {
if (d[x].len == len) x = d[x].Son[a[now]];//到了最大长度
else if (s[d[x].R - len] != a[now]) x = 0;//没到最大长度,直接看是否对上字符
if (!x) return ; num[x]++;
for (ll i = le[now]; i; i = e[i].nxt)
if (e[i].to != father && !in[e[i].to])
clac(e[i].to, now, x, len + 1);
}
void DP() {
for (ll i = 2; i <= tot; i++) {
ll now = xl[i];
num[now] += num[d[now].fa];
}
}
}S1, S2;
void dfs(ll now, ll father) {
sz[now] = 1;
for (ll i = le[now]; i; i = e[i].nxt)
if (e[i].to != father && !in[e[i].to]) {
dfs(e[i].to, now); sz[now] += sz[e[i].to];
}
}
void get_root(ll now, ll father, ll sum) {
ll maxn = sum - sz[now];
for (ll i = le[now]; i; i = e[i].nxt)
if (e[i].to != father && !in[e[i].to]) {
get_root(e[i].to, now, sum);
maxn = max(maxn, sz[e[i].to]);
}
if (maxn < max_root) max_root = maxn, root = now;
}
//
ll st[N];
void dfs1(ll now, ll father) {
st[++st[0]] = now;
for (ll i = le[now]; i; i = e[i].nxt)
if (e[i].to != father && !in[e[i].to])
dfs1(e[i].to, now);
}
void dfs2(ll now, ll father, ll pl) {
pl = S1.d[pl].son[a[now]]; if (!pl) return ;
ans += S1.d[pl].sz;
for (ll i = le[now]; i; i = e[i].nxt)
if (e[i].to != father && !in[e[i].to]) {
dfs2(e[i].to, now, pl);
}
}
void clacnn(ll now) {
st[0] = 0; dfs1(now, 0);
for (ll i = 1; i <= st[0]; i++) {
dfs2(st[i], 0, 1);
}
}
//n^2
void clac(ll now, ll father, ll op) {
S1.clear(); S2.clear();
if (father) {
S1.clac(now, 0, S1.d[1].Son[a[father]], 1);
S2.clac(now, 0, S2.d[1].Son[a[father]], 1);
}
else {
S1.clac(now, 0, 1, 0); S2.clac(now, 0, 1, 0);
}
S1.DP(); S2.DP();
for (ll i = 1; i <= m; i++)
ans += op * S1.num[pla1[i]] * S2.num[pla2[m - i + 1]];//记得第二个串翻转了所以第二个是 m-i+1
}
void slove(ll now) {
in[now] = 1;
dfs(now, 0);
if (sz[now] <= B) {//n^2
in[now] = 0;
clacnn(now);
in[now] = 1;
return ;
}
clac(now, 0, 1);
for (ll i = le[now]; i; i = e[i].nxt)
if (!in[e[i].to]) {
clac(e[i].to, now, -1);
max_root = sz[e[i].to] + 1;
get_root(e[i].to, now, sz[e[i].to]);
slove(root);
}
}
int main() {
scanf("%lld %lld", &n, &m); B = sqrt(n);
// B = 1;
for (ll i = 1; i < n; i++) {
ll x, y; scanf("%lld %lld", &x, &y);
add(x, y); add(y, x);
}
scanf("%s", s1 + 1); scanf("%s", s2 + 1);
for (ll i = 1; i <= n; i++) a[i] = s1[i] - 'a';
for (ll i = 1; i <= m; i++) s[i] = s2[i] - 'a';
S1.Init();
for (ll i = 1; i <= m; i++) pla1[i] = S1.insert(s[i]), S1.s[i] = s[i];
S1.build();
reverse(s + 1, s + m + 1);
S2.Init();
for (ll i = 1; i <= m; i++) pla2[i] = S2.insert(s[i]), S2.s[i] = s[i];
S2.build();
reverse(s + 1, s + m + 1);
dfs(1, 0);
max_root = n + 1;
get_root(1, 0, n);
slove(root);
printf("%lld", ans);
return 0;
}