具体实现的话,我们用 file.tell() 的方式获取 bytes 数量后,直接先均分成 num 块(或者按内存大小计算块的大小,然后反过来计算块的数量)。然后,对每一块寻找下一个 special token 出现的位置,调整这一块的 boundary,这样,boundary 的每一个 int 都表示了 special token 的位置,也就满足了对 special token 出现位置的约束。最后排序去重就是最终的 boundary 了。
bytes 数量差了 2 倍!所以我们需要基于“很多单词是重复的,如果单词出现 a 次,那么这个单词所构成的 byte-pairs 也至少出现 a 次”这一观察,对 chunk of text 进行 Pre-Tokenization,得到词频统计。而词频统计字典的单词数通常比 byte string 长度短太多了。
# Pretokenization, using split # regex.escape 用来转义 <|endoftext|> 中的竖线,正则里的竖线表示“或者” sentences = regex.split( "|".join(regex.escape(special) for special in specials), data, )
# 正则分割出 pre-token 然后计数 PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" tokens_count = defaultdict(int) for sentence in sentences: # 注意这里必须逐个句子处理 # 如果 "\n".join 的话,可能 join 的 "\n" 也会被当成 token 的一部分。 pre_tokens = regex.finditer( PAT, sentence, ) for token in pre_tokens: tokens_count[token.group(0)] += 1
assertnotany(tok in tokens_count for tok in specials) # assert '<|' not in tokens_count
# > 这一步,我们先把所有 str 表示的单词转成 tuple of bytes toks = defaultdict(int) for token, count in tokens.items(): btoken = token.encode("utf-8") toks[tuple(btoken[i : i + 1] for i inrange(len(btoken)))] = count
# > 先插入初始的 256 的 bytes for i inrange(256): insert_vocabulary(vocabulary, bytes([i]))
这里,合并的先后顺序很重要,因为我们需要正确模拟出 training 过程中它是怎么被合并的。 不能使用贪心法进行合并! 例如,考虑单词 abcde,我们在 training 时先合并 b c 然后再合并 a b,这意味着 b c 的数量比 a b 多。如果使用贪心法合并,我们会得到 ab c d e,而正确的是 a bc d e。和正确的 tokenization 会有出入。
def_pretokenize(self, text: str) -> Iterable[List[str]]: data = ( regex.splititer( f'({"|".join(regex.escape(s) for s in self.special_tokens)})', text, ) ifself.special_tokens else [text] ) # split into sentences, and special_tokens
for each_sentence in data: if each_sentence inself.special_tokens: yield [each_sentence] else: yield regex.findall(self.pretoken_pattern, each_sentence)
然后我们需要把每一个 pretoken 转化成 a list of bytes,对 bytes 执行合并,最后转化成 a list of index,整段文本的 encoding 结果就是所有的 list of index 拼接起来.
def_apply_merge(self, token: bytes) -> List[int]: words = [bytes([i]) for i in token]
for merge inself.merges: iflen(words) == 1: break# no more merges possible
new_word = [] skip_next = False for i inrange(len(words)): if skip_next: skip_next = False continue
if ( i != len(words) - 1 and words[i] == merge[0] and words[i + 1] == merge[1] ): new_word.append(merge[0] + merge[1]) skip_next = True else: new_word.append(words[i]) words = new_word
return [self.inverse_vocab[word] for word in words]
def_convert_to_index(self, pre_tokens: Iterable[str]) -> List[int]: tokens_bytes = [tok.encode("utf-8") for tok in pre_tokens]
result: List[int] = [] for token_byte in tokens_bytes: if token_byte inself.special_tokens_bytes: result.append(self.inverse_vocab[token_byte]) continue