BPE
- 一、 BPETrain
- 1、 unicode standard and unicode encoding
- 2、 子詞分詞(subword tokenization)
- 3、 BPE的訓練
- a、 Vocabulary initialization
- b、 Pre-tokenization
- c、 Compute BPE merges
- 4、 train_BPE更多實現上的細節
- 二、 BPETokenizer
- init函數
- from_files
- encode
- decode
- encode_iterable
- 三、 如何測試
- 四、 github完整代碼
- 五、 總結
一、 BPETrain
1、 unicode standard and unicode encoding
unicode標準是一個字符集,它將字符對應到一個整數(被稱為碼點 code point
),unicode standard 16包含154,998個字符,涵蓋168種語言。
unicode encoding是一種編碼方式,它將unicode字符對應到一個字節序列,Unicode standard定義了三種encoding 方式分別為utf-8、utf-16、utf-32,其中utf-8目前最為常用。
三者的對比如下:
特性 | UTF-8 | UTF-16 | UTF-32 |
---|---|---|---|
字符長度 | 1-4 字節 | 2 或 4 字節 | 固定 4 字節 |
ASCII 兼容性 | 完全兼容 | 不兼容 | 不兼容 |
存儲效率 | 高效(尤其是英文文本) | 中等 | 低效 |
處理復雜度 | 處理非 ASCII 字符稍復雜 | 需要處理代理對 | 簡單(固定寬度) |
適用場景 | 互聯網、文件系統、數據庫 | 操作系統、編程語言 | 特定需求 |
關于這部分內容可以看一下Python中二進制文件操作,了解一下python操作二進制字節文件的內容。
2、 子詞分詞(subword tokenization)
字詞分詞是一種介于byte-level tokenization和word-level tokenization的分詞技術,兩者的對比如下:
對比維度 | Word-level Tokenization(詞級分詞) | Byte-level Tokenization(字節級分詞) |
---|---|---|
分詞單位 | 以“詞”為最小單位(如 “你好”, “apple”, “the”) | 以“字節”為最小單位(如 b’h’, b’\xe4’) |
粒度 | 最粗粒度 | 最細粒度 |
詞匯量 | 通常幾千到幾十萬個詞 | 固定為 256 個基礎 token(0~255) |
是否支持未知詞 | 無法處理未登錄詞(OOV) | 完全支持所有字符(無 OOV) |
是否語言相關 | 通常需要語言特定的詞典或規則 | 語言無關,適用于任何語言 |
輸入長度 | 較短 | 較長(每個字符可能拆成多個字節) |
word-level無法解決oov問題,同時詞表長度太大,但是輸入長度短,而byte-level可以結局oov問題,同時詞表較短,但是輸入長度太長,對于現在的LLM,輸入長度過長會帶來更大計算量,同時會有長距離依賴問題
,為了trade-off兩者,subword
是一種很好的解決辦法。
subword-level
的思想很簡單,就是將byte sequence中出現頻次高的內容作為一個詞表中新的entry。
關于如何選擇subword加入詞表,可以使用1994年Gage提出的BPE算法(Byte pair encoding)
3、 BPE的訓練
bpe的訓練過程主要分為三步:
- 詞表初始化(Vocabulary initialization)
- 預分詞(Pre-tokenization)
- 合并(Compute BPE merges)
a、 Vocabulary initialization
詞表初始化(Vocabulary initialization):由于訓練的是byte-level的BPE初始詞表的大小應該是256。同時需要將,文本中會有一些special_tokens
,這些special_tokens
是不參與bpe的訓練的,直接將這些內容加入到初始詞表中。
## initialize vocabulary step
def initialize_vocabulary(special_tokens: list[str]
) -> dict[int, bytes]:vocabulary = {}vocabulary.update({i: special_tokens[i].encode("utf-8") for i in range(0, len(special_tokens))})vocabulary.update({i + len(vocabulary): bytes([i]) for i in range(256)}) return vocabulary
b、 Pre-tokenization
預分詞(Pre-tokenization):如果直接開始進行merge,那么每次都需要遍歷整個數據集的文本進行merge,這是一項耗時的操作,同時可能會導致dog!
、 dog.
這兩個詞僅僅因為標點符號不一樣就成為兩個完全不同的subword被分配不同的id,盡管這兩個詞在語義上高度相似,它們也被認為是兩個完全不同的詞。pre-tokenization就是為了解決上面的問題,pre-tokenization可以被看作是一種粗粒度
的tokenization,例如text是一個pre-token,同時text在全文中出現了10詞,就不再需要看(t,e)pair
在全文中出現了多少此,而是直接給(t,e)pair
增加10。
這里我實現的pre_tokenization是直接返回的dict[tuple[bytes], int]
,也就是返回的每個tuple[bytes]
出現的次數。
舉個例子這里輸入為"Hello word<|endoftext|>Hello "
,special_tokens=["<|endoftext|>"]
得到的結果就是
{(b'H', b'e', b'l', b'l', b'o'): 2,(b'w', b'o', b'r', b'l', b'd'): 1
}
這里的pre_tokenization會被拋棄,同時也沒有保留原來的順序,這樣實現其實不好,對于后面使用tokenizer進行encode是不方便的,還需要重新實現,其實可以直接保留special_tokens
, 同時保留原來的順序,使用list可以滿足這樣的要求,對于每個tuple[bytes]
出現的頻率統計可以放到下一步merge中去做。
## pre_tokenization step
def pre_tokenization(input: str, special_tokens: list[str]
) -> dict[tuple[bytes], int]:escaped_tokens = [re.escape(tok) for tok in special_tokens]split_pattern = "|".join(escaped_tokens) # 按special_tokens分割inputmatch_pattern = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") # 分割后匹配除去special_tokens中的wordsplit_texts = re.split(split_pattern, input) # 得到分割后的文本,格式為listpre_tokens = {}for split_text in split_texts:for word in match_pattern.finditer(split_text):word_str = word.group(0).encode("utf-8")bytes_word_tuple = tuple(bytes([word]) for word in word_str)pre_tokens[bytes_word_tuple] = pre_tokens.get(bytes_word_tuple, 0) + 1 return pre_tokens
c、 Compute BPE merges
合并(Compute BPE merges):迭代合并時不考慮夸pre-token的情況,同時當有多個byte pair頻率相同時,選擇字典序更大的byte pair進行合并。
出了初始詞表中的token、和BPE算法合并產生的token,還有一些special token,這些token有的用來表示元數據具有一些特殊的作用,也應該被加入到詞表中。
get_pair_freq
:由于pre_tokenization步驟得到的結果是如下格式:
{(b'H', b'e', b'l', b'l', b'o'): 2,(b'w', b'o', b'r', b'l', b'd'): 1
}
而merge需要相鄰bytes pair出現頻率最高的那一對然后合并,所以get_pair_freq
的作用就是統計相鄰bytes pair的頻率:
def get_pair_freq(word_counts: Counter[tuple[bytes]]
) -> Counter[tuple[bytes]]:freq_pair: Counter[tuple[bytes]] = {}for word, cnt in word_counts.items():for i in range(len(word) - 1):pair = (word[i], word[i + 1])freq_pair[pair] = freq_pair.get(pair, 0) + cntreturn freq_pair
find_pair
是為了獲得出現頻率最高的pair,當有多個這樣的pair時會返回字典序最大的。
## merge_tools
def find_pair(freq_pair: Counter[tuple[bytes]]
) -> tuple[bytes]:max_value = max(freq_pair.values())max_pair = max([k for k, v in freq_pair.items() if v == max_value])return max_pair
對于pre_tokenization得到的數據是一個list[tuple[bytes]]
,分開處理每一個tuple[bytes]
{(b'H', b'e', b'l', b'l', b'o'): 2,(b'w', b'o', b'r', b'l', b'd'): 1
}
也就是說處理對于上面的數據分別處理
(b'H', b'e', b'l', b'l', b'o')
(b'w', b'o', b'r', b'l', b'd')
get_merged_word
這個函數就是對每一個tuple[bytes]
進行merge,然后返回merge后得到的新tuple[bytes]
。
## merge_tools
def get_merged_word(word: tuple[bytes], cmp_pair: tuple[bytes]
) -> tuple[bytes]:new_word = [] # 存儲merge后的wordlength, cur = len(word), 0while cur < length:if cur + 1 < length: # 當還能組成的pair時if (word[cur], word[cur + 1]) == cmp_pair: # 找到了可以merge的對象new_word.append(word[cur] + word[cur + 1])cur += 2else:new_word.append(word[cur])cur += 1 else:new_word.append(word[cur])cur += 1return tuple(new_word)
4、 train_BPE更多實現上的細節
由于pre_token非常耗時,所以采用多進程并行處理,如何進行多進程并行處理?
首先是將數據集進行chunk
,具體的chunk規則可以參考assignment1-basics/cs336_basics/pretokenization_example.py中的代碼,find_chunk_boundaries
這個函數將輸入chunk成幾個完整的內容,他并不是單一嚴格按字節分割,而是會在字節后面的第一個special_token
位置進行分割。分割的方式如圖所示:
然后分別對每個分割得到chunking進行多進程并行pre-token
,多進程可以使用python的內置模塊multiprocessing
,如果不了解可以參考Python多進程并行multiprocess基礎。
下面是多進程訓練的代碼:merge_pre_tokens用于將得到的多個pre_tokens字典合并為一個字典。
def merge_pre_tokens(dicts: list[Counter[tuple[bytes]]]
) -> Counter[tuple[bytes]]:merged_counter = Counter()for counter in dicts:merged_counter.update(counter)return merged_counter## 多進程進行pre_tokenization
def parallel_pre_tokenization(file_path: str, special_tokens: list[str], num_workers: int = None
) -> Counter[tuple[bytes]]:params = []with open(file_path, 'rb') as f:boundary = find_chunk_boundaries(f, num_workers, special_tokens[0].encode("utf-8")) for left, right in zip(boundary[:-1], boundary[1:]):f.seek(left)chunk = f.read(right - left).decode("utf-8", errors="ignore")params.append((chunk, special_tokens))with Pool(processes=num_workers) as pool:result_dicts = pool.starmap(pre_tokenization, params)return merge_pre_tokens(result_dicts)
最后可以優化merge的過程,由于merge的過程會每次都去遍歷pre_tokens,然后統計byte-pair的出現次數,最后找到byte-pair的最大值作為本次merge的byte-pair。這個過程需要遍歷所有的tokens,可以采用一種增量遍歷的方式。
預處理一個全局byte-pair出現的頻次表格式如下:
freq = {(b'a', b'b'): 3,(b'a', b'c'): 2,(b'a', b'd'): 10,(b'a', b'e'): 11,(b'ad', b'e'): 13
}
本次更新選中了(b'a', b'd')
作為best_pair,找出來含有best_pair的token,然后對于一個滿足的wordA
,先全局byte-pair中把A
的所有byte-pair減去,然后加上新生成的word產生的pair。
將上面的三個步驟集成起來就得到下面的訓練函數
def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:## setp1 initinalize vocabularyvocabulary: dict[int, bytes] = initialize_vocabulary(special_tokens)## setp2 pre tokenization# file_path = "assignment1-basics/data/TinyStoriesV2-GPT4-train.txt"word_counts = parallel_pre_tokenization(input_path,special_tokens,16)cur_id: int = len(vocabulary)merges: list[tuple[bytes, bytes]] = []## step3 BPE mergeneed_merge_cnt: int = vocab_size - cur_idpair_freqs = get_pair_freq(word_counts)for i in tqdm(range(need_merge_cnt)): # 迭代merge頻次最高的byte-pairif not pair_freqs:breakbest_pair = find_pair(pair_freqs)merges.append(best_pair)vocabulary[cur_id] = best_pair[0] + best_pair[1]cur_id += 1# 找出所有需要更新的wordwords_need_update = {}for word, cnt in word_counts.items():if best_pair[0] in word and best_pair[1] in word:for i in range(len(word) - 1):if (word[i], word[i + 1]) == best_pair:words_need_update[word] = cntbreak# 更新word_countsfor word, cnt in words_need_update.items():# 增量更新pair頻率表for i in range(len(word) - 1):pair = (word[i], word[i + 1])pair_freqs[pair] = pair_freqs.get(pair, 0) - cntdel word_counts[word]new_word = get_merged_word(word, best_pair)word_counts[new_word] = word_counts.get(new_word, 0) + cntfor i in range(len(new_word) - 1):pair = (new_word[i], new_word[i + 1])pair_freqs[pair] = pair_freqs.get(pair, 0) + cntreturn vocabulary, merges
二、 BPETokenizer
cs336的文檔中已經說明了BPETokenizer類中必須實現的接口。
init函數
即初始化tokenizer,這里的vocab,merges,special_tokens都和上面訓練時的格式類型一致。
def __init__(self, vocab: dict[int, bytes], merges: list[tuple[bytes, bytes]], special_tokens: list[str] | None = None): self.vocab = vocabself.merges = mergesself.special_tokens = special_tokens
from_files
文檔中要求實現一個可以從路徑中加載vocab和merges的功能。這里我是仿照他給的pytest測試里assignment1-basics/tests/test_tokenizer.py的get_tokenizer_from_vocab_merges_path
寫的,里面的bytes_to_unicode
函數就是將256個字節都能可視化顯示,因為有很多控制字符space什么的是沒法,顯示的,這里是因為它測試讀取的vocab、merges保存的格式是這樣的所以讀取的時候還要將它保存的格式轉換為0~255的bytes。測試用的vocab、merges在assignment1-basics/tests/fixtures/gpt2_vocab.json和assignment1-basics/tests/fixtures/gpt2_merges.text。
@classmethoddef from_files(cls, vocab_filepath: str, merges_filepath: str, special_tokens: list[str] | None = None) -> BPETokenizer:@lru_cachedef bytes_to_unicode() -> dict[int, str]:bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("?"), ord("?") + 1)) + list(range(ord("?"), ord("?") + 1))cs = bs[:]n = 0for b in range(2**8):if b not in bs:bs.append(b)cs.append(2**8 + n)n += 1characters = [chr(n) for n in cs]d = dict(zip(bs, characters))return ddef bytes_to_str(b: bytes) -> str:byte_to_uni = bytes_to_unicode()s = ""for bit in b:s.join(byte_to_uni[bit])return sdef str_to_bytes(s: str) -> bytes:byte_to_uni = bytes_to_unicode()byte_decoder = {v: k for k, v in byte_to_uni.items()}ans = bytearray()for c in s:ans.extend([byte_decoder[c]])return bytes(ans)# 處理vocabtry:with open(vocab_filepath, "r", encoding="utf-8") as f:vocab_ = json.load(f)except Exception as e:raise RuntimeError(f"Error loading vocabulary from {vocab_filepath}: {e}")vocab = {v: str_to_bytes(k) for k, v in vocab_.items()}# 處理mergesmerges_ = []with open(merges_filepath, 'r', encoding="utf-8") as f:for line in f:clean_line = line.strip()if clean_line and len(clean_line.split(" ")) == 2:merges_.append(tuple(clean_line.split(" ")))if special_tokens:for special_token in special_tokens:byte_encoded_special_token = special_token.encode("utf-8")if byte_encoded_special_token not in set(vocab.values()):vocab[len(vocab)] = byte_encoded_special_tokenmerges = [(str_to_bytes(str1), str_to_bytes(str2),)for str1, str2 in merges_]return cls(vocab, merges, special_tokens)
encode
當我們訓練好了一個BPEtokenzier后,就可以通過得到vocab
和一個merge
對輸入的文本進行tokenization。
這里encode的步驟分為三步,第一步首先進行pre-tokenization,然后進行merge,最后在詞表中查看每個詞元對應的id。
首先是pre-tokenization,訓練時的pre-tokenization是先按special_tokens
進行split,然后將special_tokens
丟棄,然后再按gpt2的pat模式去分割。在使用時,不能舍棄special_tokens
,同時需要保留每個詞的順序。
十分需要注意并小心的corner case就是special_tokens為None的情況,當special_tokens為None時,不都對special_tokens使用sorted,同時第一步按special_tokens分割,結果應該是[text],list包裹原始文本。
def pre_tokenization(self,text: str, ) -> list[tuple[bytes]]:special_tokens = sorted(self.special_tokens, key=lambda x: -len(x)) if self.special_tokens is not None else []escaped_tokens = [re.escape(tok) for tok in special_tokens] if special_tokens else []split_pattern = "(" + "|".join(escaped_tokens) + ")" # 按special_tokens分割inputmatch_pattern = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")split_texts = re.split(split_pattern, text) if len(escaped_tokens) != 0 else [text]# 得到分割后的文本,格式為listpre_tokens = []for split_text in split_texts: if self.special_tokens != None and split_text in self.special_tokens:pre_tokens.append((split_text.encode('utf-8'),))else:for word in match_pattern.finditer(split_text):word_str = word.group(0).encode("utf-8")bytes_word_tuple = tuple(bytes([word]) for word in word_str)pre_tokens.append(bytes_word_tuple)return pre_tokens
其次是merge,這里merge要按保存的byte-pair
的順序去merge,因為本身文本訓練的時候merge就是按順序的,最開始寫這個merge的時候沒有按順序debug了好久。
這里merge這一步我實現了兩個函數,首先是merge函數,這個merge函數用于pre_tokenization得到的單個的tokens就是tuple[bytes]
的merge,tokens的形狀類似是這樣的(b'h', b'e', b'l', b'l', b'o')
,然后返回的結果的類型是tuple[bytes]
,對于(b'h', b'e', b'l', b'l', b'o')
這個例子,返回的結果可能是(b'he', b'll', b'o')
。
def merge(self,pre_token: tuple[bytes],ranked_merges: dict[bytes, int]) -> tuple[bytes]: while True:cur_min_rank = len(ranked_merges)best_pair = Nonefor i in range(len(pre_token) - 1):pair = pre_token[i] + pre_token[i + 1]rk = ranked_merges.get(pair, float('inf'))if rk < cur_min_rank:cur_min_rank = rkbest_pair = pairif best_pair is None:breaknew_token = []i = 0while i < len(pre_token):if i + 1 < len(pre_token) and pre_token[i] + pre_token[i + 1] == best_pair:new_token.append(best_pair)i += 2 else:new_token.append(pre_token[i])i += 1pre_token = new_tokenreturn pre_token
merge_pre_tokens是將pre_tokenization得到的pre_tokens列表里面的每個pre_tokens都應用上面的merge合并,得到的結果就是最終的合并后的tokens。
這里的小細節需要注意的是對于special_tokens,其不需要merge,也就是說我們在遍歷的時候遇到了
(b'<|endoftext|>, )'
的時候直接將他append進我們的結果列表中即可
def merge_pre_tokens(self,pre_tokens: list[tuple[bytes]],) -> list[tuple[bytes]]:merged_tokens: list[tuple[bytes]]= []special_tokens_bytes = ([tuple(special_token.encode('utf-8')) for special_token in self.special_tokens]if self.special_tokens else [])ranked_merges = {bytes1 + bytes2: idx for idx, (bytes1, bytes2) in enumerate(self.merges)}for pre_token in pre_tokens:if pre_token in special_tokens_bytes:merged_tokens.append(pre_token)else:merged_tokens.append(self.merge(pre_token, ranked_merges))return merged_tokens
最后實現文檔要求的接口encode,集成上面的功能,在vocab中查找每個bytes
對應的id
進行替換,然后返回id的列表即可。
def encode(self, text: str) -> list[int]:token_to_id = {token: id for id, token in self.vocab.items()}tokens = []pre_tokens = self.pre_tokenization(text)merged_tokens = self.merge_pre_tokens(pre_tokens)joined_tokens = []for word in merged_tokens:for b in word:joined_tokens.append(b)return [token_to_id.get(token, -1) for token in joined_tokens]
decode
decode函數就很簡單了,查找詞表vocab
,將每個token_id還原回bytes,然后進行拼接,再按utf-8
的格式decode成str即可。 errors="replace"
這個參數的作用實現的是文檔里面黃色的部分,即可能decode的輸入token_ids并非是配套的encode得到的,就可能有不合法的部分。無法用unicode解碼的就用U+FFFD替換。
def decode(self, ids: list[int]) -> str:joined_bytes = bytearray()for id in ids:joined_bytes.extend(self.vocab[id])return bytes(joined_bytes).decode("utf-8", errors="replace")
encode_iterable
當需要encode比較大的文件時,可能無法將本文全部加載進內存,這時就需要流式讀取,一部分一部分進行encode。已經實現了上面的encode這個功能就比較簡單了,就是調用一下encode,然后使用python中的yield from
即可。關于yield from
的用法參考Python中yield和yield from
def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:for chunk in iterable:if not chunk:continuetoken_ids = self.encode(chunk)yield from token_ids
三、 如何測試
這門課程可以使用pytest在本地進行測試,先進入assignment1-basic/tests
中。
要測試train_bpe部分的內容就運行
pytest train_bpe.py
上面的pytest命令會運行train_bpe.py中以**test_**開頭的所有測試函數。
要測試tokenizer部分的內容就運行:
pytest test_tokenizer.py
當然對于具體哪個測試沒過可以看一下測試代碼,也方便debug。
關于pytest的應用可以參考這個鏈接coming soon(還沒寫,寫好再發)
然后你就可以順利通過測試了~~,需要注意的是test_tokenizer的最后一個測試出現XFailed沒有關系,可以進到test中看一下那個函數,里面有說明。
四、 github完整代碼
github倉庫鏈接cs336 assignment1 BPETokenizer
五、 總結
關于這個BPE Tokenizer的細節確實很多,實現的時候也學到了很多東西。