所有代碼更新至 https://github.com/WangYuHang-cmd/CS336/tree/main/assignment1-basics
作業文件結構:
CS336/assignment1-basics/
├── tests/ # 測試文件目錄
│ ├── adapters.py # 適配器測試
│ ├── conftest.py # pytest配置
│ ├── __init__.py # 包初始化
│ ├── snapshots/ # 測試快照
│ ├── test_data.py # 數據處理測試
│ ├── test_model.py # 模型測試
│ ├── test_nn_utils.py # 神經網絡工具測試
│ ├── test_optimizer.py # 優化器測試
│ ├── test_serialization.py # 序列化測試
│ ├── test_tokenizer.py # 分詞器測試
│ └── test_train_bpe.py # BPE訓練測試
│
└── cs336_basics/ # 實現文件目錄├── attention.py # 注意力機制實現├── embedding.py # 嵌入層實現├── linear.py # 線性層實現├── optimizer.py # 優化器實現├── tokenizer.py # 分詞器實現├── transformerLM.py # Transformer語言模型├── rope.py # RoPE位置編碼├── rmsnorm.py # RMSNorm層├── softmax.py # Softmax實現├── swiglu.py # SwiGLU激活函數├── utils.py # 工具函數└── debug_*.py # 調試文件
BPE Tokenizer
BPE Class
首先是BPE類, 我們需要正確處理作業已經定義好的接口:
class BPETokenizer:def __init__(self, vocab_size: int, special_tokens: list[str] | None = None):self.vocab_size = vocab_sizeself.special_tokens = special_tokens or []self.special_tokens_bytes = [token.encode("utf-8") for token in self.special_tokens]self.merges: List[Tuple[bytes, bytes]] = []self.stoi: Dict[bytes, int] = {}self.itos: Dict[int, bytes] = {}self.merges_rank: Dict[Tuple[bytes, bytes], int] = {}# init vocabfor i, token_bytes in enumerate(self.special_tokens_bytes): # special tokensself.stoi[token_bytes] = iself.itos[i] = token_bytesoffset = len(self.special_tokens_bytes)for i in range(256):self.stoi[bytes([i])] = i + offsetself.itos[i + offset] = bytes([i])self.vocab = self.itos.copy() # for serializationself.merges_rank = {} # for fast lookup# pair2new: (p1, p2) -> new_token_idself.pair2new = {(p1, p2): self.stoi[p1 + p2] for (p1, p2) in self.merges}
其中stoi用來記錄每一個toekn對應的token id, itos用來記錄每一個token id對應的token, 在初始化的時候我們需要首先載入所有的special_tokens然后再依次將0-255對應字節值載入。
BPE Training
BPE Tokenizer是一個從data中進行學習的一個分詞器,其以Byte為單位進行學習, 然后最終學校的結果包括了單詞,詞根等各種各樣的形式。
BPE Tokenizer的核心就是首先經過預分詞得到一個token列表, 此時全文被拆成了多個pre_token組成的列表, 然后對這個列表中的special token進行提取(special token不參與合并),我們得到由一整個大列表拆出來的多個小列表,然后我們需要依次統計每一個小列表中的前后相鄰的字符pair的個數并計數, 然后按照以下規則進行合并:
1. 首先找到pair計數最多的pair <token1, token2>, 可能會有多個一樣數量的pair
2. 然后優先找token1字典序更大的,進行合并
3. 其次找token2字典序更大的進行合并
- Pre_tokenize
pretokenize這個函數主要用來將文本切分成規范的詞塊列表,例如
GPT2_SPLIT_PATTERN = (r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def pretokenize(text: str) -> list[bytes]:str_tokens = re.findall(GPT2_SPLIT_PATTERN, text)byte_tokens = [s.encode("utf-8") for s in str_tokens]return byte_tokens
例如"Hello world, this is user-123!" 會被pretokenize轉換為 [‘Hello’, ’ world’, ‘,’, ’ this’, ’ is’, ’ user’, ‘-’, ‘123’, ‘!’]
Train
我們可以很輕易寫出一個暴力的訓練方法(見代碼中的slow_train函數), 在這個函數中我們
num_merges_needed = self.vocab_size - len(self.stoi) # 需要合并的次數, 每一次合并會擴大vocab_size
for merge_cnt in range(num_merges_needed):pair_counts = self._get_stats(token_groups) # 遍歷當前的訓練列表,統計所有相鄰token_id的個數best_pair = max( pair_counts,key=lambda p: (pair_counts[p], self.itos[p[0]], self.itos[p[1]]),) # 按照合并規則找到需要合并的pair# 更新合并后的所有字典new_token_id = len(self.itos)p1_bytes, p2_bytes = self.itos[best_pair[0]], self.itos[best_pair[1]]new_token_bytes = p1_bytes + p2_bytesself.merges.append((p1_bytes, p2_bytes))self.stoi[new_token_bytes] = new_token_idself.itos[new_token_id] = new_token_bytes...
但是這不僅有可能無法通過tests/test_train_bpe.py::test_train_bpe_speed測試(我的暴力解法大約使用了5.8s遠大于限制的1.5秒), 在tests/test_train_bpe.py::test_train_bpe_special_tokens 測試中大約使用了將近7分鐘。
==================================================== 1 failed, 2 passed in 476.21s (0:07:56) ====================================================
因此我們需要考慮優化這個合并的過程:耗時的大頭是 1. “每一次都需要重新統計所有pair” 2.“更新后需要每一次重寫當前的token_id序列”, 而這些都可以通過數據結構來優化:對于token_id序列我們可以使用雙向鏈表來構建,然后對于每一個token_id對應的列表的節點位置我們可以存儲到token_id為key的set中。然后我們只需要掃一遍整個token_id的序列, 記錄每一個pair的個數然后全部push到一個堆中, 這個堆每一次會從堆頂優先pop出我們需要合并的pair. 記住,這里我們并不需要在合并pair后從內部修改這個堆,我們只需要pop出來的時候判斷一下當前的pair是否存在或者其計數是否和我們的pair_counts中一致即可。此一次修改合并后我們也只需要將合并后的pair對應的計數重新push進堆即可。考慮到每一次修改的數量不會很多, 因此總的復雜度大約是nlogn級別的
綜上我們的思路是:
數據結構:
- 維護一個大根堆,里面維護按照BPE合并的順序進行排序的token_id pair
- 維護一個雙向鏈表 用來記錄當前的token_id序列
- 維護一個為每一個token_id維護一個set,用來存儲每一個token_id對應的所有的雙向鏈表的節點的位置
更新方式:
1. 從heap頂部取出token_id的pair,判斷是否和pair_count中記錄的數量一致,若不一致則找下一個,直到一致為止,此時就是需要合并的BPE Pair
2. 通過heap中記錄的鏈表節電找到當前當前pair的 pos_idx, nxt_idx然后找到向前向后的鏈表pre_idx和nnxt_idxpre_idx <-> pos_idx <-> nxt_idx <-> nnxt_idx 我們合并后會變成pre_idx <-> (pos_idx,nxt_idx) <-> nnxt_idxnew_token = token[pos_idx] + token[nxt_idx]3. 更新pair_count,遍歷pos[token[pos_idx]]的所有鏈表節點,找到所有nxt[]對應的token_id是token[nxt_idx]的位置,然后刪除這些位置pair_count[(token[pre_idx], token[pos_idx])] - 1 pair_count[(token[pre_idx], new_token)] + 1pair_count[(token[nxt_idx], token[nnxt_idx])] - 1 pair_count[(new_token, token[nnxt_idx])] + 1pos[new_token].add(pos_idx)pre[nnxt_idx] = pos_idxnxt[pos_idx] = nnxt_idxpre[nxt_idx] = nxt[nxt_idx] = None # 刪除被合并的pair中靠后的那一個token對應的鏈表
由于python中的heapq默認使用小根堆, 因此我們需要重寫一個類來實現大根堆
class PairItem:def __init__(self, count, token_id1, token_id2, itos):self.count = countself.token_id1 = token_id1self.token_id2 = token_id2self.itos = itosself.bytes1 = itos[token_id1]self.bytes2 = itos[token_id2]def __lt__(self, other):# 首先按頻次降序(大的在前)if self.count != other.count:return self.count > other.count# 頻次相同時,按第一個token的字節降序if self.bytes1 != other.bytes1:return self.bytes1 > other.bytes1# 第一個token相同時,按第二個token的字節降序return self.bytes2 > other.bytes2def __eq__(self, other):return (self.count == other.count and self.bytes1 == other.bytes1 and self.bytes2 == other.bytes2)def get_pair(self):return (self.token_id1, self.token_id2)
然后我們讀取文本直到處理好pretokenize的結果后
# Pre-Tokenizer
assert self.vocab_size >= len(self.stoi)with open(path, "r", encoding="utf-8") as f:text = f.read()if self.special_tokens: # Special Tokenspecial_pattern = f"({'|'.join(re.escape(s) for s in self.special_tokens)})"text_parts = re.split(special_pattern, text)
else:text_parts = [text]# Pre-Tokenizer
initial_vocab_map = {v: k for k, v in self.itos.items()}
token_groups = []
for part in text_parts:if part in self.special_tokens or not part:continuewords_in_bytes = pretokenize(part)for word in words_in_bytes:token_groups.append([initial_vocab_map[bytes([b])] for b in word])
首先只需要掃一遍整體的token_id序列進行統計:
# BPE Merge
idx = 0
pair_counts = {}
token = {}
pre = {}
nxt = {}
pos = {}for i, token_lst in enumerate(token_groups):if not token_lst or len(token_lst) <= 1:continuetoken_lst_len = len(token_lst)for j, token_id in enumerate(token_lst):idx += 1token[idx] = token_idnxt[idx] = None if j == token_lst_len - 1 else idx + 1pre[idx] = None if j == 0 else idx - 1if j == token_lst_len - 1:continuetoken_pair = (token_id, token_lst[j + 1])pair_counts[token_pair] = pair_counts.get(token_pair, 0) + 1if pos.get(token_pair) is None:pos[token_pair] = set()pos[token_pair].add(idx)heap = []
for (a, b), cnt in pair_counts.items():item = PairItem(cnt, a, b, self.itos)heapq.heappush(heap, item)
然后我們可以開始BPE Merge,merge的順序和細節需要十分注意,尤其是更新的順序和對于是否更新的還存在的pair的判斷
def update_pair(pair: tuple[int, int], delta: int, pos_idx: int | None = None):if pair is None or None in pair: returnpair_counts[pair] = pair_counts.get(pair, 0) + deltacnt = pair_counts[pair]if cnt <= 0:pair_counts.pop(pair, None)pos.pop(pair, None)returnif pos_idx is not None:ds = pos.setdefault(pair, set())if delta > 0:ds.add(pos_idx)elif delta < 0:ds.discard(pos_idx)a, b = pairitem = PairItem(cnt, a, b, self.itos)heapq.heappush(heap, item)num_merges_needed = self.vocab_size - len(self.stoi)
while num_merges_needed > 0 and heap:if not pair_counts: breaknum_merges_needed -= 1while heap:item = heapq.heappop(heap)p1, p2 = item.get_pair()# 檢查這個 pair 是否仍然有效if (p1, p2) not in pair_counts or pair_counts[(p1, p2)] != item.count:continue # 已經被合并過了# merge the new tokenself.merges.append((self.itos[p1], self.itos[p2]))p1_bytes, p2_bytes = self.itos[p1], self.itos[p2]new_token_bytes = p1_bytes + p2_bytesnew_token_id = (len(self.stoi)if self.stoi.get(new_token_bytes) is Noneelse self.stoi[new_token_bytes])self.stoi[new_token_bytes] = new_token_idself.itos[new_token_id] = new_token_bytespos_lst = list(pos.get((p1, p2), set()))# modify the token groupfor pos_idx in pos_lst:pre_idx = pre[pos_idx]nxt_idx = nxt[pos_idx]nnxt_idx = nxt[nxt_idx] if nxt_idx is not None else Noneif nxt_idx is None or token[pos_idx] != p1 or token[nxt_idx] != p2: continueif pre_idx is not None:nxt[pre_idx] = pos_idx # keep unchangedupdate_pair((token[pre_idx], token[pos_idx]), -1, pre_idx)update_pair((token[pre_idx], new_token_id), 1, pre_idx)if nnxt_idx is not None:pre[nnxt_idx] = pos_idxupdate_pair((token[nxt_idx], token[nnxt_idx]), -1, nxt_idx)update_pair((new_token_id, token[nnxt_idx]), 1, pos_idx)pre[pos_idx] = pre_idxnxt[pos_idx] = nnxt_idxtoken[pos_idx] = new_token_idtoken[nxt_idx] = None # remove the old tokenpre[nxt_idx] = Nonenxt[nxt_idx] = Nonepair_counts.pop((p1, p2), None)pos.pop((p1, p2), None)breakself.merges_rank = {pair: i for i, pair in enumerate(self.merges)}
self.vocab = self.itos.copy()
self.pair2new = {(p1, p2): self.stoi[p1 + p2] for (p1, p2) in self.merges}
然后測試發現最終用時會快很多
============================================================== 3 passed in 30.85s ====================================
其中對于第一個測試從
# 暴力用時
(1752185555.6502326 - 1752185549.8956482) < 1.5
# 優化之后
tests/test_train_bpe.py::test_train_bpe_speed time using toy implementation: 0.32 seconds
當然除了重載這個堆內的排序方式外,我們還可以手動來寫比較字符串時的一個比較方式,只不過需要注意的是我們需要在短的序列末尾補大字符直到和長的一樣長(可以手動指定max_len為一個比較大的數,這個的速度也很快)
def bytes_desc(b):return bytes(255 - x for x in b)def pair_desc(pair):a = self.itos[pair[0]]b = self.itos[pair[1]]max_len = 2a_pad = a + bytes([0] * (max_len - len(a)))b_pad = b + bytes([0] * (max_len - len(b)))return (bytes_desc(a_pad), bytes_desc(b_pad))heap = [(-cnt, # 頻次取負,freq 高 → 數值小pair_desc((a, b)),a, b,) # token-1 id, token-2 idfor (a, b), cnt in pair_counts.items()
]
heapq.heapify(heap)
BPE Encode & Decode
首先是Encode部分, 這個部分需要我們將輸入的文本字符串轉換為整數ID序列,然后我們需要注意在處理的時候1.特殊token優先處理:先識別并保護特殊token(如<|endoftext|>)2. 按長度排序:避免短特殊token被長特殊token包含的情況 3.分段處理:將文本分割為特殊token和普通文本段落.
我們首先來完成不含有special token的encoder:
def _encode_ordinary_text(self, text_bytes: bytes) -> list[int]:if not text_bytes:return []try:text = text_bytes.decode("utf-8")except UnicodeDecodeError:text = text_bytes.decode("utf-8", errors="replace")ids_out = array("H") # uint16 足夠 ≤ 65k vocabpair_rank = self.merges_rankpair2new = self.pair2newbyte2id = self.stoi # 局部 alias,加速# 逐個“詞塊”處理,避免一次性 listfor word_b in iter_pretokenize(text):token_ids = array("H", (byte2id[bytes([b])] for b in word_b))# b. 就地合并:“greedy smallest-rank merge”while True:best_rank = 1000000000best_pos = -1# ——— 找當前序列里 rank 最小的 pair ———for i in range(len(token_ids) - 1):r = pair_rank.get( # ——— 替換 best_pos & best_pos+1 為新的 token ———(self.itos[token_ids[i]], self.itos[token_ids[i + 1]]),1000000000,)if r < best_rank:best_rank, best_pos = r, iif best_pos == -1:breaknew_id = pair2new[(self.itos[token_ids[best_pos]], self.itos[token_ids[best_pos + 1]])]token_ids[best_pos : best_pos + 2] = array("H", [new_id])ids_out.extend(token_ids)# array → listreturn ids_out.tolist()
在這里我使用了array而不是list,這樣每個token_id只占用2字節,逐個字符處理是防止內存爆炸
然后處理帶有特殊字符的encoder:
def encode(self, text: str) -> list[int]:"""Encode str"""if not text:return []sorted_special_tokens = sorted(self.special_tokens, key=len, reverse=True)if not sorted_special_tokens:return self._encode_ordinary_text(text.encode("utf-8"))special_pattern = f"({'|'.join(re.escape(s) for s in sorted_special_tokens)})"text_parts = re.split(special_pattern, text)all_ids = []for part in text_parts:if part in self.special_tokens:all_ids.append(self.stoi[part.encode("utf-8")])elif part:all_ids.extend(self._encode_ordinary_text(part.encode("utf-8")))return all_ids
對于decode函數則很簡單, 我們需要將一個token id序列轉換成字符串,按照BPE訓練時的合并順序:
def decode(self, ids: list[int]) -> str:all_bytes = b"".join(self.itos.get(id, b"") for id in ids)return all_bytes.decode("utf-8", errors="replace")
最后我們需要對BPETokenizer這個類進行一個序列化:
@classmethoddef from_serialized(cls,vocab: dict[int, bytes],merges: list[tuple[bytes, bytes]],special_tokens: list[str],):instance = cls(vocab_size=len(vocab), special_tokens=special_tokens)instance.stoi = {v: k for k, v in vocab.items()}instance.itos = vocabinstance.merges = mergesinstance.merges_rank = {pair: i for i, pair in enumerate(merges)}instance.vocab = vocabinstance.pair2new = {(p1, p2): instance.stoi[p1 + p2] for (p1, p2) in merges}return instance
測試結果 (注意最后一個點的XFail是正常的 說明你沒有作弊…)
============================================================== 3 passed in 30.85s ===============================================================
(llm) henry@motif-gpu:~/Desktop/LLM/CS336/assignment1-basics$ python -m pytest -q tests/test_train_bpe.py tests/test_train_bpe.py::test_train_bpe_speed time using toy implementation: 0.32 seconds
PASSED
tests/test_train_bpe.py::test_train_bpe PASSED
tests/test_train_bpe.py::test_train_bpe_special_tokens PASSED
(llm) henry@motif-gpu:~/Desktop/LLM/CS336/assignment1-basics$ python -m pytest -q tests/test_tokenizer.py tests/test_tokenizer.py::test_roundtrip_empty PASSED
tests/test_tokenizer.py::test_empty_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_character PASSED
tests/test_tokenizer.py::test_single_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_unicode_character PASSED
tests/test_tokenizer.py::test_single_unicode_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_ascii_string PASSED
tests/test_tokenizer.py::test_ascii_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_unicode_string PASSED
tests/test_tokenizer.py::test_unicode_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_unicode_string_with_special_tokens PASSED
tests/test_tokenizer.py::test_unicode_string_with_special_tokens_matches_tiktoken PASSED
tests/test_tokenizer.py::test_overlapping_special_tokens PASSED
tests/test_tokenizer.py::test_address_roundtrip PASSED
tests/test_tokenizer.py::test_address_matches_tiktoken PASSED
tests/test_tokenizer.py::test_german_roundtrip PASSED
tests/test_tokenizer.py::test_german_matches_tiktoken PASSED
tests/test_tokenizer.py::test_tinystories_sample_roundtrip PASSED
tests/test_tokenizer.py::test_tinystories_matches_tiktoken PASSED
tests/test_tokenizer.py::test_encode_special_token_trailing_newlines PASSED
tests/test_tokenizer.py::test_encode_special_token_double_newline_non_whitespace PASSED
tests/test_tokenizer.py::test_encode_iterable_tinystories_sample_roundtrip PASSED
tests/test_tokenizer.py::test_encode_iterable_tinystories_matches_tiktoken PASSED
tests/test_tokenizer.py::test_encode_iterable_memory_usage PASSED
tests/test_tokenizer.py::test_encode_memory_usage XFAIL (Tokenizer.encode is expected to take more memory than allotted (1MB).)========================================================= 24 passed, 1 xfailed in 4.50s =========================================================
TransformerLM
對于transformerLM我認為著一塊的難度比較常規,跟著課程的pdf照著寫就可以,不過很適合用來熟悉einops中einsum, reduce和rearrange的用法。以下是一些需要注意的地方。
Rope
這里forward可能會有精度問題,因此需要首先轉成torch.float32然后再轉回去即可
class RoPE(nn.Module):def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None,dtype=None):super().__init__()self.theta = thetaself.d_k = d_kself.max_seq_len = max_seq_lenself.half_dim = d_k // 2freq_seq = torch.arange(self.half_dim, dtype=torch.float32, device=device)inv_freq = 1.0 / (theta ** (freq_seq / self.half_dim))t = torch.arange(max_seq_len, dtype=torch.float32, device=device)freqs = einsum(t, inv_freq, "i, j -> i j")cos = torch.cos(freqs)sin = torch.sin(freqs)self.register_buffer("cos_cached", cos, persistent=False)self.register_buffer("sin_cached", sin, persistent=False)def forward(self,x: Float[Tensor, "... seq_len d_k"],token_positions: Int[Tensor, "... seq_len"],) -> Float[Tensor, "... seq_len d_k"]:assert x.shape[-1] == self.d_k, f"x's last dim {x.shape[-1]} != d_k {self.d_k}"assert self.d_k % 2 == 0, "d_k must be even for RoPE"in_type = x.dtypex = x.to(torch.float32)# (... seq_len d_k) -> (... seq_len d_pair 2) 2D-Tensorx_pair = rearrange(x, "... seq_len (d_pair two) -> ... seq_len d_pair two", two = 2)# cos/sin tensor buildcos = self.cos_cached[token_positions]sin = self.sin_cached[token_positions]rot_mat = torch.stack((torch.stack((cos, -sin), dim = -1),torch.stack((sin, cos), dim = -1),),dim = -2,)# rotate "i j, j -> i"x_rot = einsum(rot_mat, x_pair, "... d_pair i j, ... d_pair j -> ... d_pair i")out = rearrange(x_rot, "... seq_len d_pair two -> ... seq_len (d_pair two)", two = 2)return out.to(in_type)
TransformerBlock
TransformerBlock按照pdf的要求寫,注意模塊的復用
class TransformerBlock(nn.Module):def __init__(self,d_model: int,num_heads: int,d_ff: int,max_seq_len: int,theta: float,device=None,dtype=None,):super().__init__()assert d_model % num_heads == 0, "d_model must be divisible by num_heads"self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)self.attn = MultiheadSelfAttentionWithRoPE(d_model, num_heads, max_seq_len, theta, device, dtype)self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)self.ffn = SwiGLUFFN(d_model, d_ff, device, dtype)def forward(self,x: Float[Tensor, "batch seq_len d_model"],token_positions: Int[Tensor, "batch seq_len"] | None = None,) -> Float[Tensor, "batch seq_len d_model"]:if token_positions is None:token_positions = torch.arange(x.size(1), device=x.device).expand(x.size(0), -1)x = x + self.attn(self.ln1(x), token_positions)x = x + self.ffn(self.ln2(x))return x
TransformerLM
這里最后不需要返回softmax之后的logits, 返回softmax前一層的tensor即可
class TransformerLM(nn.Module):def __init__(self,vocab_size: int,d_model: int,num_heads: int,d_ff: int,context_length: int,theta: float,num_layers: int,device=None,dtype=None,):super().__init__()self.vocab_size = vocab_sizeself.d_model = d_modelself.num_heads = num_headsself.d_ff = d_ffself.context_length = context_lengthself.theta = thetaself.num_layers = num_layersself.device = deviceself.dtype = dtypeparam_dtype = (dtypeif (dtype is not Noneand torch.is_floating_point(torch.tensor([], dtype=dtype)))else torch.float32)self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=param_dtype)self.layers = MyLayerList([TransformerBlock(d_model=d_model,num_heads=num_heads,d_ff=d_ff,max_seq_len=context_length,theta=theta,device=device,dtype=param_dtype,)for _ in range(num_layers)])self.ln_final = RMSNorm(d_model, device=device, dtype=param_dtype)self.lm_head = Linear(d_model, vocab_size, device=device, dtype=param_dtype)@torch.no_grad()def forward(self,input_indices: Int[Tensor, "batch seq_len"],token_positions: Int[Tensor, "batch seq_len"] | None = None,) -> Float[Tensor, "batch seq_len vocab_size"]:x = self.token_embeddings(input_indices)if token_positions is None:token_positions = torch.arange(x.size(1), device=x.device).expand(x.size(0), -1)for layer in self.layers:x = layer(x, token_positions)x = self.ln_final(x)logits = self.lm_head(x)return logits
get_batch
get_batch的測試寫的不是很完善,這里可以寫成保證每一個Epoch rand的數據都不重復
def get_batch(dataset: npt.NDArray,batch_size: int,context_length: int,device: torch.device = torch.device("cpu"),
) -> tuple[npt.NDArray, npt.NDArray]:B, T = batch_size, context_lengthdata_t = torch.as_tensor(dataset, dtype=torch.long, device=device)N = data_t.numel()# starts = torch.randint(0, N - T, (B,), device=device)starts = torch.randperm(N - T, device=device)[:B] # 無放回采樣offsets = rearrange(torch.arange(T + 1, device=device), 'n -> 1 n') # [1, T+1]positions = rearrange(starts, 'b -> b 1') + offsets tokens = data_t[positions] # [B, T+1]x, y = tokens[:, :-1], tokens[:, 1:] # Next token prediction [B, T]return x, yclass EpochSampler:def __init__(self, num_positions: int, device: torch.device):self.N = num_positions self.device = deviceself._shuffle() def _shuffle(self):self.perm = torch.randperm(self.N, device=self.device)self.cursor = 0 def next(self, k: int) -> torch.Tensor:if self.cursor + k > self.N: self._shuffle()idx = self.perm[self.cursor : self.cursor + k]self.cursor += kreturn idxdef get_batch_without_same(dataset: npt.NDArray,batch_size: int,context_length: int,sampler: EpochSampler,device: torch.device = torch.device("cpu"),
) -> tuple[torch.Tensor, torch.Tensor]:B, T = batch_size, context_lengthdata_t = torch.as_tensor(dataset, dtype=torch.long, device=device) # [N_total]N = data_t.numel()starts = sampler.next(B) # shape (B,)# offsets: [1, T+1],數值 0‥Toffsets = torch.arange(T + 1, device=device).unsqueeze(0) # (1, T+1)# positions: broadcast → (B, T+1)positions = starts.unsqueeze(1) + offsetstokens = data_t[positions] # (B, T+1)x, y = tokens[:, :-1], tokens[:, 1:] # (B, T)return x, y
此外我的代碼倉庫中還提供一些debug函數,可以用來debug tokenizer和bpe_train, 在cs336_basics文件夾下
最后帖一張全部通過的圖片: