斯坦福 CS336 動手大語言模型 Assignment1 BPE Tokenizer TransformerLM

所有代碼更新至 https://github.com/WangYuHang-cmd/CS336/tree/main/assignment1-basics

作業文件結構:

CS336/assignment1-basics/
├── tests/                    # 測試文件目錄
│   ├── adapters.py          # 適配器測試
│   ├── conftest.py          # pytest配置
│   ├── __init__.py          # 包初始化
│   ├── snapshots/           # 測試快照
│   ├── test_data.py         # 數據處理測試
│   ├── test_model.py        # 模型測試
│   ├── test_nn_utils.py     # 神經網絡工具測試
│   ├── test_optimizer.py    # 優化器測試
│   ├── test_serialization.py # 序列化測試
│   ├── test_tokenizer.py    # 分詞器測試
│   └── test_train_bpe.py    # BPE訓練測試
│
└── cs336_basics/            # 實現文件目錄├── attention.py         # 注意力機制實現├── embedding.py         # 嵌入層實現├── linear.py           # 線性層實現├── optimizer.py        # 優化器實現├── tokenizer.py        # 分詞器實現├── transformerLM.py    # Transformer語言模型├── rope.py             # RoPE位置編碼├── rmsnorm.py          # RMSNorm層├── softmax.py          # Softmax實現├── swiglu.py           # SwiGLU激活函數├── utils.py            # 工具函數└── debug_*.py          # 調試文件

BPE Tokenizer

BPE Class

首先是BPE類, 我們需要正確處理作業已經定義好的接口:

class BPETokenizer:def __init__(self, vocab_size: int, special_tokens: list[str] | None = None):self.vocab_size = vocab_sizeself.special_tokens = special_tokens or []self.special_tokens_bytes = [token.encode("utf-8") for token in self.special_tokens]self.merges: List[Tuple[bytes, bytes]] = []self.stoi: Dict[bytes, int] = {}self.itos: Dict[int, bytes] = {}self.merges_rank: Dict[Tuple[bytes, bytes], int] = {}# init vocabfor i, token_bytes in enumerate(self.special_tokens_bytes):  # special tokensself.stoi[token_bytes] = iself.itos[i] = token_bytesoffset = len(self.special_tokens_bytes)for i in range(256):self.stoi[bytes([i])] = i + offsetself.itos[i + offset] = bytes([i])self.vocab = self.itos.copy()  # for serializationself.merges_rank = {}  # for fast lookup# pair2new: (p1, p2) -> new_token_idself.pair2new = {(p1, p2): self.stoi[p1 + p2] for (p1, p2) in self.merges}

其中stoi用來記錄每一個toekn對應的token id, itos用來記錄每一個token id對應的token, 在初始化的時候我們需要首先載入所有的special_tokens然后再依次將0-255對應字節值載入。

BPE Training

BPE Tokenizer是一個從data中進行學習的一個分詞器,其以Byte為單位進行學習, 然后最終學校的結果包括了單詞,詞根等各種各樣的形式。

BPE Tokenizer的核心就是首先經過預分詞得到一個token列表, 此時全文被拆成了多個pre_token組成的列表, 然后對這個列表中的special token進行提取(special token不參與合并),我們得到由一整個大列表拆出來的多個小列表,然后我們需要依次統計每一個小列表中的前后相鄰的字符pair的個數并計數, 然后按照以下規則進行合并:

1. 首先找到pair計數最多的pair <token1, token2>, 可能會有多個一樣數量的pair
2. 然后優先找token1字典序更大的,進行合并
3. 其次找token2字典序更大的進行合并
  • Pre_tokenize

pretokenize這個函數主要用來將文本切分成規范的詞塊列表,例如

GPT2_SPLIT_PATTERN = (r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def pretokenize(text: str) -> list[bytes]:str_tokens = re.findall(GPT2_SPLIT_PATTERN, text)byte_tokens = [s.encode("utf-8") for s in str_tokens]return byte_tokens

例如"Hello world, this is user-123!" 會被pretokenize轉換為 [‘Hello’, ’ world’, ‘,’, ’ this’, ’ is’, ’ user’, ‘-’, ‘123’, ‘!’]

Train

我們可以很輕易寫出一個暴力的訓練方法(見代碼中的slow_train函數), 在這個函數中我們
num_merges_needed = self.vocab_size - len(self.stoi) # 需要合并的次數, 每一次合并會擴大vocab_size

for merge_cnt in range(num_merges_needed):pair_counts = self._get_stats(token_groups) # 遍歷當前的訓練列表,統計所有相鄰token_id的個數best_pair = max( pair_counts,key=lambda p: (pair_counts[p], self.itos[p[0]], self.itos[p[1]]),)  # 按照合并規則找到需要合并的pair# 更新合并后的所有字典new_token_id = len(self.itos)p1_bytes, p2_bytes = self.itos[best_pair[0]], self.itos[best_pair[1]]new_token_bytes = p1_bytes + p2_bytesself.merges.append((p1_bytes, p2_bytes))self.stoi[new_token_bytes] = new_token_idself.itos[new_token_id] = new_token_bytes...

但是這不僅有可能無法通過tests/test_train_bpe.py::test_train_bpe_speed測試(我的暴力解法大約使用了5.8s遠大于限制的1.5秒), 在tests/test_train_bpe.py::test_train_bpe_special_tokens 測試中大約使用了將近7分鐘。

==================================================== 1 failed, 2 passed in 476.21s (0:07:56) ====================================================

因此我們需要考慮優化這個合并的過程:耗時的大頭是 1. “每一次都需要重新統計所有pair” 2.“更新后需要每一次重寫當前的token_id序列”, 而這些都可以通過數據結構來優化:對于token_id序列我們可以使用雙向鏈表來構建,然后對于每一個token_id對應的列表的節點位置我們可以存儲到token_id為key的set中。然后我們只需要掃一遍整個token_id的序列, 記錄每一個pair的個數然后全部push到一個堆中, 這個堆每一次會從堆頂優先pop出我們需要合并的pair. 記住,這里我們并不需要在合并pair后從內部修改這個堆,我們只需要pop出來的時候判斷一下當前的pair是否存在或者其計數是否和我們的pair_counts中一致即可。此一次修改合并后我們也只需要將合并后的pair對應的計數重新push進堆即可。考慮到每一次修改的數量不會很多, 因此總的復雜度大約是nlogn級別的
綜上我們的思路是:
數據結構:

  1. 維護一個大根堆,里面維護按照BPE合并的順序進行排序的token_id pair
  2. 維護一個雙向鏈表 用來記錄當前的token_id序列
  3. 維護一個為每一個token_id維護一個set,用來存儲每一個token_id對應的所有的雙向鏈表的節點的位置

更新方式

1. 從heap頂部取出token_id的pair,判斷是否和pair_count中記錄的數量一致,若不一致則找下一個,直到一致為止,此時就是需要合并的BPE Pair
2. 通過heap中記錄的鏈表節電找到當前當前pair的 pos_idx, nxt_idx然后找到向前向后的鏈表pre_idx和nnxt_idxpre_idx <-> pos_idx <-> nxt_idx <-> nnxt_idx 我們合并后會變成pre_idx <-> (pos_idx,nxt_idx) <-> nnxt_idxnew_token = token[pos_idx] + token[nxt_idx]3. 更新pair_count,遍歷pos[token[pos_idx]]的所有鏈表節點,找到所有nxt[]對應的token_id是token[nxt_idx]的位置,然后刪除這些位置pair_count[(token[pre_idx], token[pos_idx])] - 1    pair_count[(token[pre_idx], new_token)] + 1pair_count[(token[nxt_idx], token[nnxt_idx])] - 1    pair_count[(new_token, token[nnxt_idx])] + 1pos[new_token].add(pos_idx)pre[nnxt_idx] = pos_idxnxt[pos_idx] = nnxt_idxpre[nxt_idx] = nxt[nxt_idx] = None # 刪除被合并的pair中靠后的那一個token對應的鏈表

由于python中的heapq默認使用小根堆, 因此我們需要重寫一個類來實現大根堆

class PairItem:def __init__(self, count, token_id1, token_id2, itos):self.count = countself.token_id1 = token_id1self.token_id2 = token_id2self.itos = itosself.bytes1 = itos[token_id1]self.bytes2 = itos[token_id2]def __lt__(self, other):# 首先按頻次降序(大的在前)if self.count != other.count:return self.count > other.count# 頻次相同時,按第一個token的字節降序if self.bytes1 != other.bytes1:return self.bytes1 > other.bytes1# 第一個token相同時,按第二個token的字節降序return self.bytes2 > other.bytes2def __eq__(self, other):return (self.count == other.count and self.bytes1 == other.bytes1 and self.bytes2 == other.bytes2)def get_pair(self):return (self.token_id1, self.token_id2)

然后我們讀取文本直到處理好pretokenize的結果后

# Pre-Tokenizer
assert self.vocab_size >= len(self.stoi)with open(path, "r", encoding="utf-8") as f:text = f.read()if self.special_tokens:  # Special Tokenspecial_pattern = f"({'|'.join(re.escape(s) for s in self.special_tokens)})"text_parts = re.split(special_pattern, text)
else:text_parts = [text]# Pre-Tokenizer
initial_vocab_map = {v: k for k, v in self.itos.items()}
token_groups = []
for part in text_parts:if part in self.special_tokens or not part:continuewords_in_bytes = pretokenize(part)for word in words_in_bytes:token_groups.append([initial_vocab_map[bytes([b])] for b in word])

首先只需要掃一遍整體的token_id序列進行統計:

# BPE Merge
idx = 0
pair_counts = {}
token = {}
pre = {}
nxt = {}
pos = {}for i, token_lst in enumerate(token_groups):if not token_lst or len(token_lst) <= 1:continuetoken_lst_len = len(token_lst)for j, token_id in enumerate(token_lst):idx += 1token[idx] = token_idnxt[idx] = None if j == token_lst_len - 1 else idx + 1pre[idx] = None if j == 0 else idx - 1if j == token_lst_len - 1:continuetoken_pair = (token_id, token_lst[j + 1])pair_counts[token_pair] = pair_counts.get(token_pair, 0) + 1if pos.get(token_pair) is None:pos[token_pair] = set()pos[token_pair].add(idx)heap = []
for (a, b), cnt in pair_counts.items():item = PairItem(cnt, a, b, self.itos)heapq.heappush(heap, item)
然后我們可以開始BPE Merge,merge的順序和細節需要十分注意,尤其是更新的順序和對于是否更新的還存在的pair的判斷
def update_pair(pair: tuple[int, int], delta: int, pos_idx: int | None = None):if pair is None or None in pair: returnpair_counts[pair] = pair_counts.get(pair, 0) + deltacnt = pair_counts[pair]if cnt <= 0:pair_counts.pop(pair, None)pos.pop(pair, None)returnif pos_idx is not None:ds = pos.setdefault(pair, set())if delta > 0:ds.add(pos_idx)elif delta < 0:ds.discard(pos_idx)a, b = pairitem = PairItem(cnt, a, b, self.itos)heapq.heappush(heap, item)num_merges_needed = self.vocab_size - len(self.stoi)
while num_merges_needed > 0 and heap:if not pair_counts: breaknum_merges_needed -= 1while heap:item = heapq.heappop(heap)p1, p2 = item.get_pair()# 檢查這個 pair 是否仍然有效if (p1, p2) not in pair_counts or pair_counts[(p1, p2)] != item.count:continue  # 已經被合并過了# merge the new tokenself.merges.append((self.itos[p1], self.itos[p2]))p1_bytes, p2_bytes = self.itos[p1], self.itos[p2]new_token_bytes = p1_bytes + p2_bytesnew_token_id = (len(self.stoi)if self.stoi.get(new_token_bytes) is Noneelse self.stoi[new_token_bytes])self.stoi[new_token_bytes] = new_token_idself.itos[new_token_id] = new_token_bytespos_lst = list(pos.get((p1, p2), set()))# modify the token groupfor pos_idx in pos_lst:pre_idx = pre[pos_idx]nxt_idx = nxt[pos_idx]nnxt_idx = nxt[nxt_idx] if nxt_idx is not None else Noneif nxt_idx is None or token[pos_idx] != p1 or token[nxt_idx] != p2: continueif pre_idx is not None:nxt[pre_idx] = pos_idx  # keep unchangedupdate_pair((token[pre_idx], token[pos_idx]), -1, pre_idx)update_pair((token[pre_idx], new_token_id), 1, pre_idx)if nnxt_idx is not None:pre[nnxt_idx] = pos_idxupdate_pair((token[nxt_idx], token[nnxt_idx]), -1, nxt_idx)update_pair((new_token_id, token[nnxt_idx]), 1, pos_idx)pre[pos_idx] = pre_idxnxt[pos_idx] = nnxt_idxtoken[pos_idx] = new_token_idtoken[nxt_idx] = None  # remove the old tokenpre[nxt_idx] = Nonenxt[nxt_idx] = Nonepair_counts.pop((p1, p2), None)pos.pop((p1, p2), None)breakself.merges_rank = {pair: i for i, pair in enumerate(self.merges)}
self.vocab = self.itos.copy()
self.pair2new = {(p1, p2): self.stoi[p1 + p2] for (p1, p2) in self.merges}

然后測試發現最終用時會快很多

============================================================== 3 passed in 30.85s ====================================

其中對于第一個測試從

# 暴力用時
(1752185555.6502326 - 1752185549.8956482) < 1.5 
# 優化之后
tests/test_train_bpe.py::test_train_bpe_speed time using toy implementation: 0.32 seconds

當然除了重載這個堆內的排序方式外,我們還可以手動來寫比較字符串時的一個比較方式,只不過需要注意的是我們需要在短的序列末尾補大字符直到和長的一樣長(可以手動指定max_len為一個比較大的數,這個的速度也很快)

def bytes_desc(b):return bytes(255 - x for x in b)def pair_desc(pair):a = self.itos[pair[0]]b = self.itos[pair[1]]max_len = 2a_pad = a + bytes([0] * (max_len - len(a)))b_pad = b + bytes([0] * (max_len - len(b)))return (bytes_desc(a_pad), bytes_desc(b_pad))heap = [(-cnt,  # 頻次取負,freq 高 → 數值小pair_desc((a, b)),a, b,)  # token-1 id, token-2 idfor (a, b), cnt in pair_counts.items()
]
heapq.heapify(heap)

BPE Encode & Decode

首先是Encode部分, 這個部分需要我們將輸入的文本字符串轉換為整數ID序列,然后我們需要注意在處理的時候1.特殊token優先處理:先識別并保護特殊token(如<|endoftext|>)2. 按長度排序:避免短特殊token被長特殊token包含的情況 3.分段處理:將文本分割為特殊token和普通文本段落.

我們首先來完成不含有special token的encoder:

    def _encode_ordinary_text(self, text_bytes: bytes) -> list[int]:if not text_bytes:return []try:text = text_bytes.decode("utf-8")except UnicodeDecodeError:text = text_bytes.decode("utf-8", errors="replace")ids_out = array("H")  # uint16 足夠 ≤ 65k vocabpair_rank = self.merges_rankpair2new = self.pair2newbyte2id = self.stoi  # 局部 alias,加速# 逐個“詞塊”處理,避免一次性 listfor word_b in iter_pretokenize(text):token_ids = array("H", (byte2id[bytes([b])] for b in word_b))# b. 就地合并:“greedy smallest-rank merge”while True:best_rank = 1000000000best_pos = -1# ——— 找當前序列里 rank 最小的 pair ———for i in range(len(token_ids) - 1):r = pair_rank.get( # ——— 替換 best_pos & best_pos+1 為新的 token ———(self.itos[token_ids[i]], self.itos[token_ids[i + 1]]),1000000000,)if r < best_rank:best_rank, best_pos = r, iif best_pos == -1:breaknew_id = pair2new[(self.itos[token_ids[best_pos]], self.itos[token_ids[best_pos + 1]])]token_ids[best_pos : best_pos + 2] = array("H", [new_id])ids_out.extend(token_ids)# array → listreturn ids_out.tolist()

在這里我使用了array而不是list,這樣每個token_id只占用2字節,逐個字符處理是防止內存爆炸
然后處理帶有特殊字符的encoder:

    def encode(self, text: str) -> list[int]:"""Encode str"""if not text:return []sorted_special_tokens = sorted(self.special_tokens, key=len, reverse=True)if not sorted_special_tokens:return self._encode_ordinary_text(text.encode("utf-8"))special_pattern = f"({'|'.join(re.escape(s) for s in sorted_special_tokens)})"text_parts = re.split(special_pattern, text)all_ids = []for part in text_parts:if part in self.special_tokens:all_ids.append(self.stoi[part.encode("utf-8")])elif part:all_ids.extend(self._encode_ordinary_text(part.encode("utf-8")))return all_ids

對于decode函數則很簡單, 我們需要將一個token id序列轉換成字符串,按照BPE訓練時的合并順序:

def decode(self, ids: list[int]) -> str:all_bytes = b"".join(self.itos.get(id, b"") for id in ids)return all_bytes.decode("utf-8", errors="replace")

最后我們需要對BPETokenizer這個類進行一個序列化:

    @classmethoddef from_serialized(cls,vocab: dict[int, bytes],merges: list[tuple[bytes, bytes]],special_tokens: list[str],):instance = cls(vocab_size=len(vocab), special_tokens=special_tokens)instance.stoi = {v: k for k, v in vocab.items()}instance.itos = vocabinstance.merges = mergesinstance.merges_rank = {pair: i for i, pair in enumerate(merges)}instance.vocab = vocabinstance.pair2new = {(p1, p2): instance.stoi[p1 + p2] for (p1, p2) in merges}return instance

測試結果 (注意最后一個點的XFail是正常的 說明你沒有作弊…)

============================================================== 3 passed in 30.85s ===============================================================
(llm) henry@motif-gpu:~/Desktop/LLM/CS336/assignment1-basics$ python -m pytest -q tests/test_train_bpe.py tests/test_train_bpe.py::test_train_bpe_speed time using toy implementation: 0.32 seconds
PASSED
tests/test_train_bpe.py::test_train_bpe PASSED
tests/test_train_bpe.py::test_train_bpe_special_tokens PASSED
(llm) henry@motif-gpu:~/Desktop/LLM/CS336/assignment1-basics$ python -m pytest -q tests/test_tokenizer.py tests/test_tokenizer.py::test_roundtrip_empty PASSED
tests/test_tokenizer.py::test_empty_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_character PASSED
tests/test_tokenizer.py::test_single_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_unicode_character PASSED
tests/test_tokenizer.py::test_single_unicode_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_ascii_string PASSED
tests/test_tokenizer.py::test_ascii_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_unicode_string PASSED
tests/test_tokenizer.py::test_unicode_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_unicode_string_with_special_tokens PASSED
tests/test_tokenizer.py::test_unicode_string_with_special_tokens_matches_tiktoken PASSED
tests/test_tokenizer.py::test_overlapping_special_tokens PASSED
tests/test_tokenizer.py::test_address_roundtrip PASSED
tests/test_tokenizer.py::test_address_matches_tiktoken PASSED
tests/test_tokenizer.py::test_german_roundtrip PASSED
tests/test_tokenizer.py::test_german_matches_tiktoken PASSED
tests/test_tokenizer.py::test_tinystories_sample_roundtrip PASSED
tests/test_tokenizer.py::test_tinystories_matches_tiktoken PASSED
tests/test_tokenizer.py::test_encode_special_token_trailing_newlines PASSED
tests/test_tokenizer.py::test_encode_special_token_double_newline_non_whitespace PASSED
tests/test_tokenizer.py::test_encode_iterable_tinystories_sample_roundtrip PASSED
tests/test_tokenizer.py::test_encode_iterable_tinystories_matches_tiktoken PASSED
tests/test_tokenizer.py::test_encode_iterable_memory_usage PASSED
tests/test_tokenizer.py::test_encode_memory_usage XFAIL (Tokenizer.encode is expected to take more memory than allotted (1MB).)========================================================= 24 passed, 1 xfailed in 4.50s =========================================================

TransformerLM

對于transformerLM我認為著一塊的難度比較常規,跟著課程的pdf照著寫就可以,不過很適合用來熟悉einops中einsum, reduce和rearrange的用法。以下是一些需要注意的地方。

Rope

這里forward可能會有精度問題,因此需要首先轉成torch.float32然后再轉回去即可

class RoPE(nn.Module):def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None,dtype=None):super().__init__()self.theta = thetaself.d_k = d_kself.max_seq_len = max_seq_lenself.half_dim = d_k // 2freq_seq = torch.arange(self.half_dim, dtype=torch.float32, device=device)inv_freq = 1.0 / (theta ** (freq_seq / self.half_dim))t = torch.arange(max_seq_len, dtype=torch.float32, device=device)freqs = einsum(t, inv_freq, "i, j -> i j")cos = torch.cos(freqs)sin = torch.sin(freqs)self.register_buffer("cos_cached", cos, persistent=False)self.register_buffer("sin_cached", sin, persistent=False)def forward(self,x: Float[Tensor, "... seq_len d_k"],token_positions: Int[Tensor, "... seq_len"],) -> Float[Tensor, "...  seq_len d_k"]:assert x.shape[-1] == self.d_k, f"x's last dim {x.shape[-1]} != d_k {self.d_k}"assert self.d_k % 2 == 0, "d_k must be even for RoPE"in_type = x.dtypex = x.to(torch.float32)# (... seq_len d_k) ->  (... seq_len d_pair 2) 2D-Tensorx_pair = rearrange(x, "... seq_len (d_pair two) -> ... seq_len d_pair two", two = 2)# cos/sin tensor buildcos = self.cos_cached[token_positions]sin = self.sin_cached[token_positions]rot_mat = torch.stack((torch.stack((cos, -sin), dim = -1),torch.stack((sin, cos), dim = -1),),dim = -2,)# rotate "i j, j -> i"x_rot = einsum(rot_mat, x_pair, "... d_pair i j, ... d_pair j -> ... d_pair i")out = rearrange(x_rot, "... seq_len d_pair two -> ... seq_len (d_pair two)", two = 2)return  out.to(in_type)
TransformerBlock

TransformerBlock按照pdf的要求寫,注意模塊的復用

class TransformerBlock(nn.Module):def __init__(self,d_model: int,num_heads: int,d_ff: int,max_seq_len: int,theta: float,device=None,dtype=None,):super().__init__()assert d_model % num_heads == 0, "d_model must be divisible by num_heads"self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)self.attn = MultiheadSelfAttentionWithRoPE(d_model, num_heads, max_seq_len, theta, device, dtype)self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)self.ffn = SwiGLUFFN(d_model, d_ff, device, dtype)def forward(self,x: Float[Tensor, "batch seq_len d_model"],token_positions: Int[Tensor, "batch seq_len"] | None = None,) -> Float[Tensor, "batch seq_len d_model"]:if token_positions is None:token_positions = torch.arange(x.size(1), device=x.device).expand(x.size(0), -1)x = x + self.attn(self.ln1(x), token_positions)x = x + self.ffn(self.ln2(x))return x
TransformerLM

這里最后不需要返回softmax之后的logits, 返回softmax前一層的tensor即可

class TransformerLM(nn.Module):def __init__(self,vocab_size: int,d_model: int,num_heads: int,d_ff: int,context_length: int,theta: float,num_layers: int,device=None,dtype=None,):super().__init__()self.vocab_size = vocab_sizeself.d_model = d_modelself.num_heads = num_headsself.d_ff = d_ffself.context_length = context_lengthself.theta = thetaself.num_layers = num_layersself.device = deviceself.dtype = dtypeparam_dtype = (dtypeif (dtype is not Noneand torch.is_floating_point(torch.tensor([], dtype=dtype)))else torch.float32)self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=param_dtype)self.layers = MyLayerList([TransformerBlock(d_model=d_model,num_heads=num_heads,d_ff=d_ff,max_seq_len=context_length,theta=theta,device=device,dtype=param_dtype,)for _ in range(num_layers)])self.ln_final = RMSNorm(d_model, device=device, dtype=param_dtype)self.lm_head = Linear(d_model, vocab_size, device=device, dtype=param_dtype)@torch.no_grad()def forward(self,input_indices: Int[Tensor, "batch seq_len"],token_positions: Int[Tensor, "batch seq_len"] | None = None,) -> Float[Tensor, "batch seq_len vocab_size"]:x = self.token_embeddings(input_indices)if token_positions is None:token_positions = torch.arange(x.size(1), device=x.device).expand(x.size(0), -1)for layer in self.layers:x = layer(x, token_positions)x = self.ln_final(x)logits = self.lm_head(x)return logits
get_batch

get_batch的測試寫的不是很完善,這里可以寫成保證每一個Epoch rand的數據都不重復

def get_batch(dataset: npt.NDArray,batch_size: int,context_length: int,device: torch.device = torch.device("cpu"),
) -> tuple[npt.NDArray, npt.NDArray]:B, T = batch_size, context_lengthdata_t = torch.as_tensor(dataset, dtype=torch.long, device=device)N = data_t.numel()# starts = torch.randint(0, N - T, (B,), device=device)starts = torch.randperm(N - T, device=device)[:B]  # 無放回采樣offsets   = rearrange(torch.arange(T + 1, device=device), 'n -> 1 n')  # [1, T+1]positions = rearrange(starts, 'b -> b 1') + offsets          tokens = data_t[positions]          # [B, T+1]x, y   = tokens[:, :-1], tokens[:, 1:]   # Next token prediction [B, T]return x, yclass EpochSampler:def __init__(self, num_positions: int, device: torch.device):self.N = num_positions            self.device = deviceself._shuffle()                   def _shuffle(self):self.perm = torch.randperm(self.N, device=self.device)self.cursor = 0                   def next(self, k: int) -> torch.Tensor:if self.cursor + k > self.N: self._shuffle()idx = self.perm[self.cursor : self.cursor + k]self.cursor += kreturn idxdef get_batch_without_same(dataset: npt.NDArray,batch_size: int,context_length: int,sampler: EpochSampler,device: torch.device = torch.device("cpu"),
) -> tuple[torch.Tensor, torch.Tensor]:B, T = batch_size, context_lengthdata_t = torch.as_tensor(dataset, dtype=torch.long, device=device)   # [N_total]N = data_t.numel()starts = sampler.next(B)                    # shape (B,)# offsets: [1, T+1],數值 0‥Toffsets = torch.arange(T + 1, device=device).unsqueeze(0)            # (1, T+1)# positions: broadcast → (B, T+1)positions = starts.unsqueeze(1) + offsetstokens = data_t[positions]                  # (B, T+1)x, y = tokens[:, :-1], tokens[:, 1:]        # (B, T)return x, y

此外我的代碼倉庫中還提供一些debug函數,可以用來debug tokenizer和bpe_train, 在cs336_basics文件夾下

最后帖一張全部通過的圖片:

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/914342.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/914342.shtml
英文地址,請注明出處:http://en.pswp.cn/news/914342.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Spring Cloud Gateway 實戰指南

關鍵詞&#xff1a;微服務、API網關、Spring Cloud Gateway、路由轉發、限流熔斷 ? 文章摘要 隨著互聯網應用規模的不斷擴大&#xff0c;傳統的單體架構逐漸向微服務架構轉型。在微服務架構中&#xff0c;API 網關作為系統的入口點&#xff0c;承擔了諸如請求路由、負載均衡、…

PyTorch自動微分:從基礎到實戰

目錄 1. 自動微分是什么&#xff1f; 1.1 計算圖 1.2 requires_grad 屬性 2. 標量和向量的梯度計算 2.1 標量梯度 2.2 向量梯度 3. 梯度上下文控制 3.1 禁用梯度計算 3.2 累計梯度 4. 梯度下降實戰 4.1 求函數最小值 4.2 線性回歸參數求解 5. 總結 在深度學習中&a…

Spring AI 項目實戰(十六):Spring Boot + AI + 通義萬相圖像生成工具全棧項目實戰(附完整源碼)

系列文章 序號文章名稱1Spring AI 項目實戰(一):Spring AI 核心模塊入門2Spring AI 項目實戰(二):Spring Boot + AI + DeepSeek 深度實戰(附完整源碼)3Spring AI 項目實戰(三):Spring Boot + AI + DeepSeek 打造智能客服系統(附完整源碼)4

從零到一:企業如何組建安全團隊

在這個"黑客滿天飛&#xff0c;漏洞遍地跑"的時代&#xff0c;沒有安全團隊的企業就像裸奔的勇士——雖然很有勇氣&#xff0c;但結局往往很悲慘。 &#x1f4cb; 目錄 為什么要組建安全團隊安全團隊的核心職能團隊架構設計人員配置策略技術體系建設制度流程建立實施…

業務訪問控制-ACL與包過濾

業務訪問控制-ACL與包過濾 ACL的定義及應用場景ACL&#xff08;Access Control List&#xff0c;訪問控制列表&#xff09;是用來實現數據包識別功能的&#xff1b;ACL可以應用于諸多場景&#xff1a; 包過濾功能&#xff1a;對數據包進行放通或過濾操作。NAT&#xff08;Netwo…

穿梭時空的智慧向導:Deepoc具身智能如何賦予導覽機器人“人情味”

穿梭時空的智慧向導&#xff1a;Deepoc具身智能如何賦予導覽機器人“人情味”清晨&#xff0c;當第一縷陽光透過高大的彩繪玻璃窗&#xff0c;灑在博物館光潔的地板上&#xff0c;一位特別的“館員”已悄然“蘇醒”。它沒有制服&#xff0c;卻有著清晰的指引&#xff1b;它無需…

PostgreSQL 查詢庫中所有表占用磁盤大小、表大小

SELECTn.nspname AS schema_name,c.relname AS table_name,-- 1?? 總大小&#xff08;表 toast 索引&#xff09;pg_size_pretty(pg_total_relation_size(c.oid)) AS total_size,-- 2?? 表不包含索引&#xff08;含 TOAST&#xff09;pg_size_pretty(pg_total_relation_s…

日記-生活隨想

最近鼠鼠也是來到上海打拼&#xff08;實習&#xff09;了&#xff0c;那么秉持著來都來了的原則&#xff0c;鼠鼠也是去bw逛了逛&#xff0c;雖說沒票只能在外場看看&#x1f62d;。可惜幾乎沒有多少我非常喜歡的ip&#xff0c;不由感慨現在的二次元圈已經變樣了。雖說我知道內…

串口A和S的含義以及RT的含義

A async 異步S sync 同步RT 收發U A RT 異步U SA RT 同步/異步

spring cloud負載均衡分析之FeignBlockingLoadBalancerClient、BlockingLoadBalancerClient

本文主要分析被 FeignClient 注解的接口類請求過程中負載均衡邏輯&#xff0c;流程分析使用的依賴版本信息如下&#xff1a;<spring-boot.version>3.2.1</spring-boot.version><spring-cloud.version>2023.0.0</spring-cloud.version><com.alibaba.…

ref 和 reactive

文章目錄ref 和 reactive一、差異二、能否替代的場景分析&#xff08;1&#xff09;基本類型數據&#xff08;2&#xff09;對象類型數據&#xff08;3&#xff09;數組類型數據&#xff08;4&#xff09; 需要整體替換的場景三、替代方案與兼容寫法1. 用 reactive 模擬 ref2. …

BatchNorm 與 LayerNorm:原理、實現與應用對比

BatchNorm 與 LayerNorm&#xff1a;原理、實現與應用對比 Batch Normalization (批歸一化) 和 Layer Normalization (層歸一化) 是深度學習中兩種核心的歸一化技術&#xff0c;它們解決了神經網絡訓練中的內部協變量偏移問題&#xff0c;大幅提升了模型訓練的穩定性和收斂速度…

OcsNG基于debian一鍵部署腳本

&#x1f914; 為什么有了GLPI還要部署OCS-NG&#xff1f; 核心問題&#xff1a;數據收集的風險 GLPI直接收集的問題&#xff1a; Agent直接向GLPI報告數據時&#xff0c;任何收集異常都會直接影響資產數據庫網絡問題、Agent故障可能導致重復資產、錯誤數據、資產丟失無法對收集…

001_Claude開發者指南介紹

Claude開發者指南介紹 目錄 Claude簡介Claude 4 模型開始使用核心功能支持資源 Claude簡介 Claude 是由 Anthropic 構建的高性能、可信賴和智能的 AI 平臺。Claude 具備出色的語言、推理、分析和編程能力&#xff0c;可以幫助您解決各種復雜任務。 想要與 Claude 聊天嗎&a…

004_Claude功能特性與API使用

Claude功能特性與API使用 目錄 API 基礎使用核心功能特性高級功能開發工具平臺支持 API 基礎使用 快速開始 通過 Anthropic Console 獲取 API 訪問權限&#xff1a; 在 console.anthropic.com/account/keys 生成 API 密鑰使用 Workbench 在瀏覽器中測試 API 認證方式 H…

ReAct論文解讀(1)—什么是ReAct?

什么是ReAct&#xff1f; 在大語言模型&#xff08;LLM&#xff09;領域中&#xff0c;ReAct 指的是一種結合了推理&#xff08;Reasoning&#xff09; 和行動&#xff08;Acting&#xff09; 的提示方法&#xff0c;全稱是 “ReAct: Synergizing Reasoning and Acting in Lan…

【云服務器安全相關】服務器防火墻常見系統日志信息說明

目錄? 一、防火墻日志是做什么的&#xff1f;&#x1f6e0;? 二、常見防火墻日志信息及說明&#x1f9ea; 三、典型日志示例解析1. 被阻斷的訪問&#xff08;DROP&#xff09;2. 被允許的訪問&#xff08;ACCEPT&#xff09;3. 被拒絕的端口訪問4. 可疑端口掃描行為&#x1f…

011_視覺能力與圖像處理

視覺能力與圖像處理 目錄 視覺能力概述支持的圖像格式圖像上傳方式使用限制最佳實踐應用場景API使用示例視覺能力概述 多模態交互 Claude 3 系列模型具備強大的視覺理解能力,可以分析和理解圖像內容,實現真正的多模態AI交互。這種能力使Claude能夠: 圖像內容分析:理解圖…

ansible自動化部署考試系統前后端分離項目

1. ?ansible編寫劇本步驟1??創建roles目錄結構2??在group_vars/all/main.yml中定義變量列表3??在tasks目錄下編寫tasks任務4??在files目錄下準備部署文件5??在templates目錄下創建j2模板文件6??在handlers目錄下編寫handlers7??在roles目錄下編寫主playbook8??…

【AI論文】GLM-4.1V-Thinking:邁向具備可擴展強化學習的通用多模態推理

摘要&#xff1a;我們推出GLM-4.1V-Thinking&#xff0c;這是一款旨在推動通用多模態推理發展的視覺語言模型&#xff08;VLM&#xff09;。在本報告中&#xff0c;我們分享了在以推理為核心的訓練框架開發過程中的關鍵發現。我們首先通過大規模預訓練開發了一個具備顯著潛力的…