双向Trie

  |  

摘要: 在一个 Trie 节点同事维护前缀和后缀

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


$0 双向Trie

场景

一个 Trie 可以预处理一批单词的所有前缀,或者所有后缀。此后可以快速进行与前缀匹配或者后缀匹配有关的查询。

有时需求会更近一步:即做一些前缀和后缀均需要匹配的查询。例如问前缀为 profix 同时后缀为 suffix 的单词有哪些。

Trie 朴素方案

此时一种直接的方案是维护两个 Trie,一个保存前缀信息,一个保存后缀信息。在做同时要求前缀和后缀的查询 query(prefix, suffix) 时,先找到两颗树上对应的两个节点,然后在两个节点持有的信息中进行搜索,例如取交集等。

这种方法在取得两个节点后的搜索阶段可能产生瓶颈。

双向 Trie 方案

双向 Trie 是在 Trie 的基础上更进一步,一个节点同时维护前缀和后缀两个方向的扩展,每次各扩展一个字符。

例如 ('a', 'e'),当插入 apple 这个词的时候,('a', 'e') 就是可以扩展的节点。因为 a 是 apple 的前缀的下一个字符同时 e 是 apple 后缀的下一个字符。

前缀和后缀各有 26 个字符,但是需要考虑只扩展前缀或者只扩展后缀的情况,这两种情况可以视为另一个方向扩展了空字符,例如 (' ', 'a') 表示只扩展后缀字符 a,('b', ' ') 表示只扩展前缀字符 b。

这样相当于前缀和后缀均有 27 种选择,对应地共有 27 * 27 个子节点。

双向 Trie 实现

节点定义

如果用数组实现 children,一个节点有 27 * 27 个子节点,空间成本过高。因此改用哈希表实现 children

在推进前后缀的时候,如果前缀或后缀停了,即推进了空字符,则同方向后续也只能推进空字符,因此需要来两个标记来记录当前节点代表的分支上前缀以及后缀方向是否已经停了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
struct DoubleTrieNode
{
bool prefix_terminate;
bool suffix_terminate;
unordered_map<string, DoubleTrieNode*> children;
// 节点信息 node_info;
DoubleTrieNode()
{
prefix_terminate = false;
suffix_terminate = false;
children = unordered_map<string, DoubleTrieNode*>();
max_idx = -1;
}
};

插入

用两个指针 i,j 表示当前节点的前后缀对应在插入单词上的位置,s[i], s[j]

然后检查三种情况:

  • i 是否可以向右推进
  • j 是否可以向左推进
  • i, j 是否可以同时推进
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
void insert(const string& s, int w)
{
int n = s.size();
_insert(root, s, 0, n - 1, w);
}

void _insert(DoubleTrieNode* node, const string& s, int i, int j, const int w)
{
// 更新 node_info
int n = s.size();
if(!node -> prefix_terminate && !node -> suffix_terminate)
{
if(i <= n - 1 && j >= 0)
{
string nxt_char_set = string(1, s[i]) + s[j];
DoubleTrieNode *&nxt = (node -> children)[nxt_char_set];
if(!nxt)
nxt = new DoubleTrieNode();
_insert(nxt, s, i + 1, j - 1, w);
}
}
if(i <= n - 1 && !node -> prefix_terminate)
{
string nxt_char_set = string(1, s[i]) + '#';
DoubleTrieNode *&nxt = (node -> children)[nxt_char_set];
if(!nxt)
nxt = new DoubleTrieNode();
nxt -> suffix_terminate = true;
_insert(nxt, s, i + 1, j, w);
}
if(j >= 0 && !node -> suffix_terminate)
{
string nxt_char_set = '#' + string(1, s[j]);
DoubleTrieNode *&nxt = (node -> children)[nxt_char_set];
if(!nxt)
nxt = new DoubleTrieNode();
nxt -> prefix_terminate = true;
_insert(nxt, s, i, j - 1, w);
}
}

查找

用两个指针表示当前节点的前后缀对应在前缀串和后缀串的位置 prefix[i], suffix[j]

i 向右推进,j 向左推进,如果可以同时推进则同时推进,直到 i,j 都耗尽。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
int search(const string& prefix, const string& suffix)
{
auto it = _find(prefix, suffix);
if(!it)
return -1;
// 节点信息 it -> node_info;
}

DoubleTrieNode* _find(const string& prefix, const string& suffix)
{
int np = prefix.size();
int ns = suffix.size();
int i = 0, j = ns - 1;
DoubleTrieNode *iter = root;
while(i < np && j >= 0)
{
string nxt_char_set = string(1, prefix[i]) + suffix[j];
DoubleTrieNode *&nxt = (iter -> children)[nxt_char_set];
if(!nxt)
return nullptr;
iter = nxt;
++i;
--j;
}
while(i < np)
{
string nxt_char_set = string(1, prefix[i]) + '#';
DoubleTrieNode *&nxt = (iter -> children)[nxt_char_set];
if(!nxt)
return nullptr;
iter = nxt;
++i;
}
while(j >= 0)
{
string nxt_char_set = '#' + string(1, suffix[j]);
DoubleTrieNode *&nxt = (iter -> children)[nxt_char_set];
if(!nxt)
return nullptr;
iter = nxt;
--j;
}
return iter;
}

$1 题目

设计一个包含一些单词的特殊词典,并能够通过前缀和后缀来检索单词。

实现 WordFilter 类:

  • WordFilter(string[] words) 使用词典中的单词 words 初始化对象。
  • f(string pref, string suff) 返回词典中具有前缀 prefix 和后缀 suff 的单词的下标。如果存在不止一个满足要求的下标,返回其中 最大的下标 。如果不存在这样的单词,返回 -1 。

提示:

1
2
3
4
5
1 <= words.length <= 1e4
1 <= words[i].length <= 7
1 <= pref.length, suff.length <= 7
words[i]、pref 和 suff 仅由小写英文字母组成
最多对函数 f 执行 1e4 次调用

示例:
输入
[“WordFilter”, “f”]
[[[“apple”]], [“a”, “e”]]
输出
[null, 0]
解释
WordFilter wordFilter = new WordFilter([“apple”]);
wordFilter.f(“a”, “e”); // 返回 0 ,因为下标为 0 的单词:前缀 prefix = “a” 且 后缀 suff = “e” 。

算法1:双向 Trie

代码 (C++,模板)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
struct DoubleTrieNode
{
bool prefix_terminate;
bool suffix_terminate;
unordered_map<string, DoubleTrieNode*> children;
int max_idx;
DoubleTrieNode()
{
prefix_terminate = false;
suffix_terminate = false;
children = unordered_map<string, DoubleTrieNode*>();
max_idx = -1;
}
};

class DoubleTrie
{
public:
DoubleTrie()
{
root = new DoubleTrieNode();
}

~DoubleTrie()
{
if(root)
_delete_sub_tree(root);
}

void insert(const string& s, int w)
{
int n = s.size();
_insert(root, s, 0, n - 1, w);
}

int search(const string& prefix, const string& suffix)
{
auto it = _find(prefix, suffix);
if(!it)
return -1;
return it -> max_idx;
}

private:
DoubleTrieNode *root;

void _delete_sub_tree(DoubleTrieNode* node)
{
for(auto child: node -> children)
if(child.second)
_delete_sub_tree(child.second);
delete node;
node = nullptr;
}

DoubleTrieNode* _find(const string& prefix, const string& suffix)
{
int np = prefix.size();
int ns = suffix.size();
int i = 0, j = ns - 1;
DoubleTrieNode *iter = root;
while(i < np && j >= 0)
{
string nxt_char_set = string(1, prefix[i]) + suffix[j];
DoubleTrieNode *&nxt = (iter -> children)[nxt_char_set];
if(!nxt)
return nullptr;
iter = nxt;
++i;
--j;
}
while(i < np)
{
string nxt_char_set = string(1, prefix[i]) + '#';
DoubleTrieNode *&nxt = (iter -> children)[nxt_char_set];
if(!nxt)
return nullptr;
iter = nxt;
++i;
}
while(j >= 0)
{
string nxt_char_set = '#' + string(1, suffix[j]);
DoubleTrieNode *&nxt = (iter -> children)[nxt_char_set];
if(!nxt)
return nullptr;
iter = nxt;
--j;
}
return iter;
}

void _insert(DoubleTrieNode* node, const string& s, int i, int j, const int w)
{
node -> max_idx = max(node -> max_idx, w);
int n = s.size();
if(!node -> prefix_terminate && !node -> suffix_terminate)
{
if(i <= n - 1 && j >= 0)
{
string nxt_char_set = string(1, s[i]) + s[j];
DoubleTrieNode *&nxt = (node -> children)[nxt_char_set];
if(!nxt)
nxt = new DoubleTrieNode();
_insert(nxt, s, i + 1, j - 1, w);
}
}
if(i <= n - 1 && !node -> prefix_terminate)
{
string nxt_char_set = string(1, s[i]) + '#';
DoubleTrieNode *&nxt = (node -> children)[nxt_char_set];
if(!nxt)
nxt = new DoubleTrieNode();
nxt -> suffix_terminate = true;
_insert(nxt, s, i + 1, j, w);
}
if(j >= 0 && !node -> suffix_terminate)
{
string nxt_char_set = '#' + string(1, s[j]);
DoubleTrieNode *&nxt = (node -> children)[nxt_char_set];
if(!nxt)
nxt = new DoubleTrieNode();
nxt -> prefix_terminate = true;
_insert(nxt, s, i, j - 1, w);
}
}
};

class WordFilter {
public:
WordFilter(vector<string>& words) {
trie = new DoubleTrie();
int n = words.size();
for(int i = 0; i < n; ++i)
{
string w = words[i];
trie -> insert(w, i);
}
}

int f(string prefix, string suffix) {
return trie -> search(prefix, suffix);
}

private:
DoubleTrie *trie;
};

算法2:前缀Trie + 后缀Trie

维护一个前缀树和一个后缀树。对于每个单词 words[i],在前缀树和后缀树中均做预处理:

words[i] 所有前缀节点和后缀节点记录下标 i

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
const int ALPHABET = 26;

struct TrieNode
{
vector<TrieNode*> children;
set<int> word_ids;
TrieNode()
{
children = vector<TrieNode*>(ALPHABET);
word_ids = set<int>();
}
};

class Trie
{
public:
Trie()
{
root_prefix = new TrieNode();
root_suffix = new TrieNode();
}

~Trie()
{
if(root_prefix)
_delete_sub_tree(root_prefix);
if(root_suffix)
_delete_sub_tree(root_suffix);
}

void insert_prefix(const string& s, int w)
{
TrieNode *iter = root_prefix;
(iter -> word_ids).insert(w);
for(const char &ch: s)
{
if(!(iter -> children)[ch - 'a'])
(iter -> children)[ch - 'a'] = new TrieNode();
iter = (iter -> children)[ch - 'a'];
(iter -> word_ids).insert(w);
}
}

void insert_suffix(const string& s, int w)
{
TrieNode *iter = root_suffix;
(iter -> word_ids).insert(w);
int n = s.size();
for(int i = n - 1; i >= 0; --i)
{
const char &ch = s[i];
if(!(iter -> children)[ch - 'a'])
(iter -> children)[ch - 'a'] = new TrieNode();
iter = (iter -> children)[ch - 'a'];
(iter -> word_ids).insert(w);
}
}

int search(const string& prefix, const string& suffix)
{
auto p1 = find(root_prefix, prefix);
auto p2 = find(root_suffix, suffix);
if(!p1 || !p2)
return -1;
set<int>& words_prefix = p1 -> word_ids;
set<int>& words_suffix = p2 -> word_ids;
auto it = words_prefix.rbegin();
while(it != words_prefix.rend())
{
if(words_suffix.count(*it) > 0)
return *it;
++it;
}
return -1;
}

private:
TrieNode *root_suffix;
TrieNode *root_prefix;

TrieNode* find(TrieNode* root, const string& prefix)
{
TrieNode *iter = root;
for(const char &ch: prefix)
{
if(!(iter -> children)[ch - 'a'])
return nullptr;
iter = (iter -> children)[ch - 'a'];
}
return iter;
}

void _delete_sub_tree(TrieNode* node)
{
for(TrieNode *child: node -> children)
{
if(child)
_delete_sub_tree(child);
}
delete node;
node = nullptr;
}
};

class WordFilter {
public:
WordFilter(vector<string>& words) {
trie = new Trie();
int n = words.size();
for(int i = 0; i < n; ++i)
{
string w = words[i];
trie -> insert_prefix(w, i);
trie -> insert_suffix(w, i);
}
}

int f(string prefix, string suffix) {
reverse(suffix.begin(), suffix.end());
return trie -> search(prefix, suffix);
}

private:
Trie *trie;
};

Share