BPE Training 的流程和实现
Tokenization 是如何工作的呢?假定我们的文本放在 corpus.txt
文件里(这里的文本带有 special tokens 如 <|endoftext|>
)
通常,我们的文本文件很大,例如几个 GB,我们不可能把这些字符全都载入内存里(肯定爆炸),因此,第一步,我们需要对训练文本分块 ,以便发挥多线程的优势,让计算机并行处理多个任务。分块需要注意,每一块都必须是以 special token 结尾的,否则如果横跨了某个单词,那么这个单词被 tokenize 的结果很可能与 expected 的不同。
chunkify
实现
根据上面所说的,我们把文本进行分块。我们用 Binary IO 的方式打开文本 ,最后返回 List[int]
表示 chunk 的边界为 B[i-1] ~ B[i]
具体实现的话,我们用 file.tell()
的方式获取 bytes 数量后,直接先均分成 num
块(或者按内存大小计算块的大小,然后反过来计算块的数量)。然后,对每一块寻找下一个 special token 出现的位置,调整这一块的 boundary,这样,boundary 的每一个 int
都表示了 special token 的位置,也就满足了对 special token 出现位置的约束。最后排序去重就是最终的 boundary 了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 def chunkify (file: BinaryIO, specials: List [bytes ], num_chunks: int ) -> List [int ]: file.seek(0 , os.SEEK_END) file_size = file.tell() file.seek(0 ) chunk_size = file_size // num_chunks boundaries = [i * chunk_size for i in range (num_chunks + 1 )] boundaries[-1 ] = file_size mini_chunk = 4096 for i in range (1 , len (boundaries) - 1 ): init_pos = boundaries[i] file.seek(init_pos) while True : sub_chunk = file.read(mini_chunk) if sub_chunk == b"" : boundaries[i] = file_size break special_pos = [ sub_chunk.find(token) for token in specials if sub_chunk.find(token) != -1 ] if len (special_pos) != 0 : special_pos = min (special_pos) boundaries[i] = init_pos + special_pos break init_pos += mini_chunk return sorted (set (boundaries))
然后就要 tokenize 了。但是对于很大的语料库而言,会包含很多相同的单词(例如 the 这个单词可以在很多地方出现很多次),如果我们 naive 地对一长串 bytes 计算 byte-pair count,时间复杂度是很高的(bytes 实在太多了)。比如说……
1 2 3 4 >>> print (len ('这是一段文字' ))6 >>> print (len ('这是一段文字' .encode('utf-8' )))18
bytes 数量差了 2 倍!所以我们需要基于“很多单词是重复的,如果单词出现 a a a 次,那么这个单词所构成的 byte-pairs 也至少出现 a a a 次 ”这一观察,对 chunk of text 进行 Pre-Tokenization ,得到词频统计。而词频统计字典的单词数通常比 byte string 长度短太多了。
对于英文文本来说,OpenAI 在 GPT2 里曾使用过 RegEx 分割单词(主要通过空格、句号等);对于中文、日文等不依赖空格的语言,基本上也有对应的库,如 jieba
等等。这里就只展示英文 pre-tokenization 的做法。Pre-Tokenization 也可以利用多线程并行 。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 def count_token (file: BinaryIO, specials: List [str ], start: int , end: int ): file.seek(start) data = file.read(end - start).decode("utf-8" , errors="ignore" ) sentences = regex.split( "|" .join(regex.escape(special) for special in specials), data, ) PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" tokens_count = defaultdict(int ) for sentence in sentences: pre_tokens = regex.finditer( PAT, sentence, ) for token in pre_tokens: tokens_count[token.group(0 )] += 1 assert not any (tok in tokens_count for tok in specials) return tokens_count
得到词频统计后,我们就可以 merge bytes 了。merge bytes 的过程是不断合并出现次数最多的 byte-pair,具体来说就是:
我们首先需要把单词表示成 a sequence of bytes,我们的字典保存的是 Dict[Tuple[bytes, ...], int]
,即词频(但是词是 tuple of bytes)
遍历词典,统计 byte-pair count
取出 count 最多的 byte-pair,如果有多个,则取 byte value 更大的那个
merges
记录下这个要合并的 byte-pair,vocabulary
也新增一个条目记录这个 byte-pair(他们即将成为一个新的 token)
遍历词典,如果某个单词含有这个 byte-pair 则合并,在词典里更新其 tuple of bytes representation。
如,我想合并 x
和 y
,而某个单词的 tuple of bytes 是 [a, b, c, x, y, d, f, e]
,那么合并后就变成了 [a, b, c, xy, d, f, e]
.
这里有一个小小的优化:显然,合并完一个 byte-pair 之后,只有“在这个单词里和这个 byte-pair 相交的其他 byte-pair 的数量会受到影响”。基于这一点,我们遍历单词的时候,同时检查和当前 byte-pair 相交的其他 byte-pair,然后减去单词的出现次数,并新增条目(受影响的 byte-pair 的一部分和新的 byte-pair 形成的 token)。这样就可以相对高效地进行 merge.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 def bytepair ( tokens: defaultdict[str , int ], vocab_size: int , init_vocab: defaultdict[int , bytes ], init_merge: List [Tuple [bytes , bytes ]], ) -> Tuple [Dict [int , bytes ], List [Tuple [bytes , bytes ]]]: vocabulary = init_vocab merges = init_merge toks = defaultdict(int ) for token, count in tokens.items(): btoken = token.encode("utf-8" ) toks[tuple (btoken[i : i + 1 ] for i in range (len (btoken)))] = count for i in range (256 ): insert_vocabulary(vocabulary, bytes ([i])) bp_cnt = construct(toks) for _ in tqdm( range (vocab_size - len (vocabulary)), desc="merging byte token pairs" , total=vocab_size, initial=len (vocabulary), ): count, pair = max ([(cnt, tokpair) for tokpair, cnt in bp_cnt.items()]) newbyte = pair[0 ] + pair[1 ] logger.debug(f"merging {pair[0 ]} and {pair[1 ]} " ) insert_vocabulary(vocabulary, pair[0 ] + pair[1 ]) merges.append(pair) bp_cnt.pop(pair) affected = [] for token in toks: if not contain(token, pair): continue new_token = [] skip_next = False for i in range (len (token)): if skip_next: skip_next = False continue if i < len (token) - 1 and (token[i], token[i + 1 ]) == pair: new_token.append(pair[0 ] + pair[1 ]) if i != 0 : bp_cnt[(token[i - 1 ], token[i])] -= toks[token] bp_cnt[(token[i - 1 ], newbyte)] += toks[token] if i + 1 != len (token) - 1 : bp_cnt[(token[i + 1 ], token[i + 2 ])] -= toks[token] bp_cnt[(newbyte, token[i + 2 ])] += toks[token] skip_next = True else : new_token.append(token[i]) new_token = tuple (new_token) affected.append((token, toks[token], new_token)) for old, cnt, new in affected: toks.pop(old) toks[new] = cnt return (vocabulary, merges)
Serialization 注意事项
我们把 merges
和 vocabulary
写入文件时,如果直接写入,会遇到一个问题:
有一些 bytes 无法以 ASCII 的形式呈现,比如说 b'\x80'
有一些空白字符比如空格,如果直接写入文件,日后再读取的时候解析就会比较困难。merges.txt
有时
面对这些情况我们有一些处理方法:我们考虑把所有无法以 ASCII 呈现的字符映射到 ≥ 256 \ge 256 ≥ 256 的字符上。可以使用 utility 工具 gpt2_bytes_to_unicode()
。其原理就是先筛选出 < 256 \lt 256 < 256 里 printable 的字符,然后对于剩下的字符,映射到 ≥ 256 \ge 256 ≥ 256 的字符上,返回一个字典。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 def bytes2unicode_serializer () -> Dict [int , str ]: bytes_list = ( list (range (ord ("!" ), ord ("~" ) + 1 )) + list (range (ord ("¡" ), ord ("¬" ) + 1 )) + list (range (ord ("®" ), ord ("ÿ" ) + 1 )) ) copy = bytes_list[:] n = 0 for b in range (256 ): if b in bytes_list: continue bytes_list.append(b) copy.append(256 + n) n += 1 d = dict (zip (bytes_list, [chr (x) for x in copy])) return d
Tokenizer Encode/Decode 的流程和实现
regex.split()
的使用
regex.split(REGEX, STRING)
可以按 REGEX
分割字符串,但是默认不会保留匹配了 REGEX
的部分。可以通过在外面加一个圆括号 (REGEX)
让 regex
能够保留匹配的部分。
Encoding
Tokenizer Encoding 的过程是接受一个字符串 str
,然后输出 List[int]
。
我们现在已有的是
What We Have
Type
Meaning
vocabulary
Dict[int, bytes]
记录了 bytes 对应的编码
merges
List[Tuple[bytes, bytes]]
记录了 bytes 合并的先后顺序
这里,合并的先后 顺序很重要,因为我们需要正确模拟出 training 过程中它是怎么被合并的。 不能使用贪心法进行合并! 例如,考虑单词 abcde
,我们在 training 时先合并 b c
然后再合并 a b
,这意味着 b c
的数量比 a b
多。如果使用贪心法合并,我们会得到 ab c d e
,而正确的是 a bc d e
。和正确的 tokenization 会有出入。