本文將介紹以下內容:
- 1. BPE 算法核心原理
- 2. BPE 算法流程
- 3. BPE 算法源碼實現Demo
BPE最早是一種數據壓縮算法,由Sennrich等人于2015年引入到NLP領域并很快得到推廣。該算法簡單有效,因而目前它是最流行的方法。GPT-2和RoBERTa使用的Subword算法都是BPE。
1. BPE 算法核心原理:
它的主要思想是:
- 使用頻率統計來逐步合并高頻的字符/子詞對。
- 從最小的單位(字符)開始,逐漸學習得到一套子詞詞表,使模型能夠兼顧 常見詞的完整表示 和 罕見詞的組合表示。
在大語言模型時代,最常用的分詞方法是Byte-Pair Encoding(BPE)和Byte-level BPE(BBPE)。該算法的核心思想是逐步合并出頻率最高的子詞對而不是像wordpiece一樣通過計算合并分數。
2. BPE 算法流程:
(1)計算初始詞表:通過訓練語料獲得或者最初的英文種26個字母加上各種符號以及常見中文字符,這些作為初始詞表。
(2)構建頻率統計:統計所有子詞單元對在文本中的出現頻率。
(3)合并頻率最高的子詞對:選擇出現頻率最高的子詞對,將它們合并成一個新的子詞單元,并更新詞匯表。
(4)重復合并步驟:不斷重復步驟2和步驟3,直到達到預定的詞匯表大小、合并次數。
(5)分詞:使用訓練得到的詞匯表對文本進行分詞。
3. 算法源碼實現Demo
import re
from collections import defaultdictclass BPE:def __init__(self, vocab_size=100):self.vocab_size = vocab_sizeself.vocab = {} # word -> frequencyself.merges = [] # list of mergesself.bpe_ranks = {} # pair -> rank# ---------- 構建初始詞表 ----------def build_vocab(self, corpus):"""corpus: list[str],輸入語料英文: 用空格分詞中文: 逐字處理"""vocab = defaultdict(int)for line in corpus:words = line.strip().split()for word in words:chars = list(word) + ["</w>"] # 加上詞邊界vocab[tuple(chars)] += 1self.vocab = dict(vocab)# ---------- 統計 pair ----------def get_stats(self):"""統計 pair 的頻率"""pairs = defaultdict(int)for word, freq in self.vocab.items():for i in range(len(word)-1):pairs[(word[i], word[i+1])] += freqreturn pairs# ---------- 合并 ----------def merge_vocab(self, pair):"""執行一次合并"""new_vocab = {}bigram = re.escape(" ".join(pair))pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')for word, freq in self.vocab.items():word_str = " ".join(word)new_word = tuple(pattern.sub("".join(pair), word_str).split())new_vocab[new_word] = freqself.vocab = new_vocab# ---------- 訓練 ----------def train(self, save_merges="merges.txt", save_vocab="vocab.txt"):# 初始 alphabet 大小alphabet = set(ch for word in self.vocab for ch in word)num_merges = self.vocab_size - len(alphabet)print(f"初始alphabet大小={len(alphabet)},目標vocab_size={self.vocab_size},合并次數≈{num_merges}")for i in range(num_merges):pairs = self.get_stats()if not pairs:breakbest = max(pairs, key=pairs.get)self.merges.append(best)self.merge_vocab(best)# 構建 bpe_ranksself.bpe_ranks = dict(zip(self.merges, range(len(self.merges))))print(f"self.bpe_ranks:{self.bpe_ranks}")# 保存 mergeswith open(save_merges, "w", encoding="utf-8") as f:for a, b in self.merges:f.write(f"{a} {b}\n")# 保存 vocabvocab_tokens = set()for word in self.vocab:for token in word:vocab_tokens.add(token)with open(save_vocab, "w", encoding="utf-8") as f:for token in sorted(vocab_tokens):f.write(token + "\n")print(f"? merges 保存到 {save_merges}, vocab 保存到 {save_vocab}")# ---------- 推理 ----------def get_pairs(self, word):"""獲取當前詞的所有pair"""pairs = set()prev_char = word[0]for char in word[1:]:pairs.add((prev_char, char))prev_char = charreturn pairsdef encode_word(self, word):"""BPE 編碼單個詞"""word = tuple(list(word) + ["</w>"])pairs = self.get_pairs(word)if not pairs:return [word]while True:# 找到rank最小的pairbigram = min(pairs, key=lambda p: self.bpe_ranks.get(p, float("inf")))if bigram not in self.bpe_ranks:breaknew_word = []i = 0while i < len(word):if i < len(word)-1 and word[i] == bigram[0] and word[i+1] == bigram[1]:new_word.append(word[i] + word[i+1])i += 2else:new_word.append(word[i])i += 1word = tuple(new_word)if len(word) == 1:breakpairs = self.get_pairs(word)return list(word)def decode_word(self, tokens):"""還原單詞"""word = "".join(tokens)if word.endswith("</w>"):word = word[:-4]return worddef encode_sentence(self, sentence):"""BPE 編碼整句"""return [self.encode_word(w) for w in sentence.strip().split()]def decode_sentence(self, tokens_list):"""解碼整句"""return " ".join(self.decode_word(toks) for toks in tokens_list)# ================== 示例 ==================
if __name__ == "__main__":corpus = ["deep learning is the future of ai","see my eyes first","see my dogs","you are the best","you are the fast","machine learning can be applied to natural language processing","深度學習是人工智能的未來","機器學習可以應用于自然語言處理","人工智能改變世界","學習深度神經網絡在圖像識別中表現優秀"]# 訓練bpe = BPE(vocab_size=100)bpe.build_vocab(corpus)bpe.train("merges.txt", "vocab.txt")# 測試推理print("\n=== 單詞測試 ===")for w in ["lowest", "newer", "人工智能", "深度學習"]:tokens = bpe.encode_word(w)print(f"{w} -> {tokens} -> {bpe.decode_word(tokens)}")print("\n=== 句子測試 ===")sentence = "lowest newer 人工智能深度學習"tokens_list = bpe.encode_sentence(sentence)print(tokens_list)print(bpe.decode_sentence(tokens_list))# === 單詞測試 ===
# lowest -> ['l', 'o', 'w', 'e', 's', 't', '</w>'] -> lowest
# newer -> ['n', 'e', 'w', 'e', 'r', '</w>'] -> newer
# 人工智能 -> ['人工智能', '</w>'] -> 人工智能
# 深度學習 -> ['深度', '學習', '</w>'] -> 深度學習# === 句子測試 ===
# [['l', 'o', 'w', 'e', 's', 't', '</w>'], ['n', 'e', 'w', 'e', 'r', '</w>'], ['人工智能', '王', '贊', '深度', '學習', '</w>']]
# lowest newer 人工智能深度學習