BPE Training 的流程和实现

CAUTION

本部分对应的是 Section 2 的 BPE.


Tokenization 是如何工作的呢?假定我们的文本放在 corpus.txt 文件里(这里的文本带有 special tokens 如 <|endoftext|>

通常,我们的文本文件很大,例如几个 GB,我们不可能把这些字符全都载入内存里(肯定爆炸),因此,第一步,我们需要对训练文本分块,以便发挥多线程的优势,让计算机并行处理多个任务。分块需要注意,每一块都必须是以 special token 结尾的,否则如果横跨了某个单词,那么这个单词被 tokenize 的结果很可能与 expected 的不同。

chunkify 实现

根据上面所说的,我们把文本进行分块。我们用 Binary IO 的方式打开文本,最后返回 List[int] 表示 chunk 的边界为 B[i-1] ~ B[i]

具体实现的话,我们用 file.tell() 的方式获取 bytes 数量后,直接先均分成 num 块(或者按内存大小计算块的大小,然后反过来计算块的数量)。然后,对每一块寻找下一个 special token 出现的位置,调整这一块的 boundary,这样,boundary 的每一个 int 都表示了 special token 的位置,也就满足了对 special token 出现位置的约束。最后排序去重就是最终的 boundary 了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def chunkify(file: BinaryIO, specials: List[bytes], num_chunks: int) -> List[int]:
# get total bytes
file.seek(0, os.SEEK_END)
file_size = file.tell()
file.seek(0)

chunk_size = file_size // num_chunks
boundaries = [i * chunk_size for i in range(num_chunks + 1)]
boundaries[-1] = file_size

mini_chunk = 4096

for i in range(1, len(boundaries) - 1):
init_pos = boundaries[i]
file.seek(init_pos)

# 这里就是不断读取 4096 个字节,然后找有没有 special token
while True:
sub_chunk = file.read(mini_chunk)

if sub_chunk == b"":
boundaries[i] = file_size
break

special_pos = [
sub_chunk.find(token)
for token in specials
if sub_chunk.find(token) != -1
]
# 读进来的小 chunk 有 special token,那么直接截断
if len(special_pos) != 0:
special_pos = min(special_pos)
boundaries[i] = init_pos + special_pos
break
# 否则继续找 special token
init_pos += mini_chunk

return sorted(set(boundaries))

然后就要 tokenize 了。但是对于很大的语料库而言,会包含很多相同的单词(例如 the 这个单词可以在很多地方出现很多次),如果我们 naive 地对一长串 bytes 计算 byte-pair count,时间复杂度是很高的(bytes 实在太多了)。比如说……

1
2
3
4
>>> print(len('这是一段文字'))
6
>>> print(len('这是一段文字'.encode('utf-8')))
18

bytes 数量差了 2 倍!所以我们需要基于“很多单词是重复的,如果单词出现 aa 次,那么这个单词所构成的 byte-pairs 也至少出现 aa”这一观察,对 chunk of text 进行 Pre-Tokenization,得到词频统计。而词频统计字典的单词数通常比 byte string 长度短太多了。

对于英文文本来说,OpenAI 在 GPT2 里曾使用过 RegEx 分割单词(主要通过空格、句号等);对于中文、日文等不依赖空格的语言,基本上也有对应的库,如 jieba 等等。这里就只展示英文 pre-tokenization 的做法。Pre-Tokenization 也可以利用多线程并行

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# file: 文件
# specials: 特殊 token,需要先按 special token 将段落分割成互不干扰的小段落。
# start, end: 因为是多线程进行 pre-tokenization,这里的 start, end 对应 chunkify 出来的一个 chunk
def count_token(file: BinaryIO, specials: List[str], start: int, end: int):
file.seek(start)
data = file.read(end - start).decode("utf-8", errors="ignore")

# Pretokenization, using split
# regex.escape 用来转义 <|endoftext|> 中的竖线,正则里的竖线表示“或者”
sentences = regex.split(
"|".join(regex.escape(special) for special in specials),
data,
)

# 正则分割出 pre-token 然后计数
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
tokens_count = defaultdict(int)
for sentence in sentences:
# 注意这里必须逐个句子处理
# 如果 "\n".join 的话,可能 join 的 "\n" 也会被当成 token 的一部分。
pre_tokens = regex.finditer(
PAT,
sentence,
)
for token in pre_tokens:
tokens_count[token.group(0)] += 1

assert not any(tok in tokens_count for tok in specials)
# assert '<|' not in tokens_count

return tokens_count

得到词频统计后,我们就可以 merge bytes 了。merge bytes 的过程是不断合并出现次数最多的 byte-pair,具体来说就是:

  1. 我们首先需要把单词表示成 a sequence of bytes,我们的字典保存的是 Dict[Tuple[bytes, ...], int],即词频(但是词是 tuple of bytes)
  2. 遍历词典,统计 byte-pair count
  3. 取出 count 最多的 byte-pair,如果有多个,则取 byte value 更大的那个
  4. merges 记录下这个要合并的 byte-pair,vocabulary 也新增一个条目记录这个 byte-pair(他们即将成为一个新的 token)
  5. 遍历词典,如果某个单词含有这个 byte-pair 则合并,在词典里更新其 tuple of bytes representation。

    如,我想合并 xy,而某个单词的 tuple of bytes 是 [a, b, c, x, y, d, f, e],那么合并后就变成了 [a, b, c, xy, d, f, e].

这里有一个小小的优化:显然,合并完一个 byte-pair 之后,只有“在这个单词里和这个 byte-pair 相交的其他 byte-pair 的数量会受到影响”。基于这一点,我们遍历单词的时候,同时检查和当前 byte-pair 相交的其他 byte-pair,然后减去单词的出现次数,并新增条目(受影响的 byte-pair 的一部分和新的 byte-pair 形成的 token)。这样就可以相对高效地进行 merge.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def bytepair(
tokens: defaultdict[str, int],
vocab_size: int,
init_vocab: defaultdict[int, bytes],
init_merge: List[Tuple[bytes, bytes]],
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
vocabulary = init_vocab
merges = init_merge

# > 这一步,我们先把所有 str 表示的单词转成 tuple of bytes
toks = defaultdict(int)
for token, count in tokens.items():
btoken = token.encode("utf-8")
toks[tuple(btoken[i : i + 1] for i in range(len(btoken)))] = count

# > 先插入初始的 256 的 bytes
for i in range(256):
insert_vocabulary(vocabulary, bytes([i]))

# > bp_cnt 的作用就是统计 byte-pair counting
bp_cnt = construct(toks)

for _ in tqdm( # > 这里加了一个进度条可视化
range(vocab_size - len(vocabulary)),
desc="merging byte token pairs",
total=vocab_size,
initial=len(vocabulary),
):
# > 提取出现次数最多的 byte-pair
count, pair = max([(cnt, tokpair) for tokpair, cnt in bp_cnt.items()])
newbyte = pair[0] + pair[1]
logger.debug(f"merging {pair[0]} and {pair[1]}")
# > 插入词汇表
insert_vocabulary(vocabulary, pair[0] + pair[1])
# > 记录 merge
merges.append(pair)
# > 当前的 pair 会被记录成一个 token,不算 byte-pair 了
# > 因此从 byte-pair counting 里删除
bp_cnt.pop(pair)

affected = []
for token in toks:
# > 如果单词不包含这个 byte-pair 那么直接跳过
if not contain(token, pair):
continue

new_token = []
# > 下面的循环是将 token 里的 byte-pair 合并起来
skip_next = False
for i in range(len(token)):
if skip_next:
skip_next = False
continue

if i < len(token) - 1 and (token[i], token[i + 1]) == pair:
new_token.append(pair[0] + pair[1])
# > 这里就是上面说的优化,只影响与 byte-pair 相交的 bytes
# > 两个 if 语句考虑的边界的情况
if i != 0:
bp_cnt[(token[i - 1], token[i])] -= toks[token]
bp_cnt[(token[i - 1], newbyte)] += toks[token]
if i + 1 != len(token) - 1:
bp_cnt[(token[i + 1], token[i + 2])] -= toks[token]
bp_cnt[(newbyte, token[i + 2])] += toks[token]
skip_next = True
else:
new_token.append(token[i])

new_token = tuple(new_token)
affected.append((token, toks[token], new_token))

# > 由于更新了单词的 bytes 表示,所以单词表也要同步更新
for old, cnt, new in affected:
toks.pop(old)
toks[new] = cnt
return (vocabulary, merges)
Serialization 注意事项

我们把 mergesvocabulary 写入文件时,如果直接写入,会遇到一个问题:

  1. 有一些 bytes 无法以 ASCII 的形式呈现,比如说 b'\x80'
  2. 有一些空白字符比如空格,如果直接写入文件,日后再读取的时候解析就会比较困难。merges.txt 有时

面对这些情况我们有一些处理方法:我们考虑把所有无法以 ASCII 呈现的字符映射到 256\ge 256 的字符上。可以使用 utility 工具 gpt2_bytes_to_unicode()。其原理就是先筛选出 <256\lt 256 里 printable 的字符,然后对于剩下的字符,映射到 256\ge 256 的字符上,返回一个字典。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def bytes2unicode_serializer() -> Dict[int, str]:
bytes_list = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
copy = bytes_list[:]

n = 0
for b in range(256):
if b in bytes_list:
continue
bytes_list.append(b)
copy.append(256 + n)
n += 1

d = dict(zip(bytes_list, [chr(x) for x in copy]))
return d

Tokenizer Encode/Decode 的流程和实现

regex.split() 的使用

regex.split(REGEX, STRING) 可以按 REGEX 分割字符串,但是默认不会保留匹配了 REGEX 的部分。可以通过在外面加一个圆括号 (REGEX)regex 能够保留匹配的部分。

Encoding

Tokenizer Encoding 的过程是接受一个字符串 str,然后输出 List[int]

我们现在已有的是

What We Have Type Meaning
vocabulary Dict[int, bytes] 记录了 bytes 对应的编码
merges List[Tuple[bytes, bytes]] 记录了 bytes 合并的先后顺序

这里,合并的先后顺序很重要,因为我们需要正确模拟出 training 过程中它是怎么被合并的。 不能使用贪心法进行合并! 例如,考虑单词 abcde,我们在 training 时先合并 b c 然后再合并 a b,这意味着 b c 的数量比 a b 多。如果使用贪心法合并,我们会得到 ab c d e,而正确的是 a bc d e。和正确的 tokenization 会有出入。