本题可以用暴力法解决,但为了练习 AC 自动机的实现,这里使用 AC 自动机。
AC 自动机主要维护两个结构:一个是 ch,ch[f][idx] 表示从父节点 f 沿字符 idx 这条边走到的子节点;另一个是 nex,nex[i] 表示节点 i 的失配指针(fail 指针,即回跳边)所指向的节点。
from collections import defaultdict, deque
from typing import List


class Solution:
    """Bold every substring of ``s`` that matches a word, via an Aho-Corasick automaton."""

    def addBoldTag(self, s: str, words: List[str]) -> str:
        """Return ``s`` with each maximal run of matched characters wrapped in ``<b>...</b>``.

        A character is "matched" when it lies inside some occurrence of a word
        from ``words``. Overlapping and directly adjacent occurrences are merged
        into a single ``<b>...</b>`` span.

        Only edges that really exist in the trie are stored; missing transitions
        are resolved through fail links while scanning. This avoids filling a
        dense transition table over the whole alphabet for every node, which is
        a lot of pure-Python dict work.
        """
        # ch[f]: children of trie node f, keyed by character (node 0 is the root).
        ch = [{}]
        # nex[i]: fail link of node i -- the node for the longest proper suffix
        # of node i's string that is also a path in the trie.
        nex = [0]
        # out[i]: length of the longest word that is a suffix of node i's
        # string (0 if none); folded along fail links during the build.
        out = [0]

        # 1. Insert every word into the trie.
        for word in words:
            f = 0
            for c in word:
                child = ch[f].get(c)
                if child is None:
                    child = len(ch)
                    ch[f][c] = child
                    ch.append({})
                    nex.append(0)
                    out.append(0)
                f = child
            out[f] = max(out[f], len(word))

        # 2. Breadth-first construction of fail links; depth-1 nodes keep nex == 0.
        queue = deque(ch[0].values())
        while queue:
            f = queue.popleft()
            for c, child in ch[f].items():
                g = nex[f]
                while g and c not in ch[g]:
                    g = nex[g]
                nex[child] = ch[g].get(c, 0)
                # nex[child] is strictly shallower than child, so BFS has
                # already finalized its out value.
                out[child] = max(out[child], out[nex[child]])
                queue.append(child)

        # 3. Scan s. Only the longest word ending at each position needs to be
        # recorded: shorter words ending there lie inside it. Coverage is kept
        # in a difference array so marking costs O(1) per position.
        n = len(s)
        diff = [0] * (n + 1)
        f = 0
        for i, c in enumerate(s):
            while f and c not in ch[f]:
                f = nex[f]
            f = ch[f].get(c, 0)
            length = out[f]
            if length:
                diff[i - length + 1] += 1
                diff[i + 1] -= 1

        # 4. Emit the result, opening/closing a tag whenever coverage toggles.
        pieces = []
        covered = 0
        inside = False
        for i, c in enumerate(s):
            covered += diff[i]
            if covered and not inside:
                pieces.append('<b>')
                inside = True
            elif not covered and inside:
                pieces.append('</b>')
                inside = False
            pieces.append(c)
        if inside:
            pieces.append('</b>')
        return ''.join(pieces)
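确实通过了,但暴力解法居然比 AC 自动机更快——问题出在哪里?
下面是暴力解法,上面是 AC 自动机解法。
原因主要在常数开销上:暴力解法的子串比较(切片加 `==`)由 Python 内部的 C 代码完成,而 AC 自动机在 Python 层逐字符做字典查找和失配跳转。若构建失配指针时还对每个节点遍历全部 62 个字符(数字加大小写字母),构建本身就需要约 62 × 节点数 次 Python 级操作。在本题这种数据规模较小的情况下,这些开销会超过 AC 自动机在渐近复杂度上的优势。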
暴力代码:
from typing import List


class Solution:
    """Bold every substring of ``s`` that matches a word, by brute-force search."""

    def addBoldTag(self, s: str, words: List[str]) -> str:
        """Return ``s`` with each maximal run of matched characters wrapped in ``<b>...</b>``.

        Every occurrence of every word is located with ``str.find``. Overlapping
        and directly adjacent occurrences are merged into a single
        ``<b>...</b>`` span. Empty words cover no characters and are ignored.
        """
        n = len(s)
        # diff[i] += 1 where an occurrence starts and -= 1 just past its end;
        # the running sum is the number of occurrences covering position i.
        diff = [0] * (n + 1)
        for word in words:
            width = len(word)
            if not width:
                continue  # an empty word would otherwise "match" everywhere
            start = s.find(word)
            while start != -1:
                diff[start] += 1
                diff[start + width] -= 1
                # Advance by one so overlapping occurrences are found too.
                start = s.find(word, start + 1)

        # Emit the result, opening/closing a tag whenever coverage toggles.
        pieces = []
        covered = 0
        inside = False
        for i, c in enumerate(s):
            covered += diff[i]
            if covered and not inside:
                pieces.append('<b>')
                inside = True
            elif not covered and inside:
                pieces.append('</b>')
                inside = False
            pieces.append(c)
        if inside:
            pieces.append('</b>')
        return ''.join(pieces)