LLM基礎5_從零開始實現 GPT 模型

基于GitHub項目:https://github.com/datawhalechina/llms-from-scratch-cn

設計 LLM 的架構

GPT 模型基于 Transformer 的?decoder-only?架構,其主要特點包括:

  • 順序生成文本

  • 參數數量龐大(而非代碼量復雜)

  • 大量重復的模塊化組件

以 GPT-2 small 模型(124M 參數)為例,其配置如下:

GPT_CONFIG_124M = {"vocab_size": 50257,  # BPE 分詞器詞表大小"ctx_len": 1024,      # 最大上下文長度"emb_dim": 768,       # 嵌入維度"n_heads": 12,        # 注意力頭數量"n_layers": 12,       # Transformer 塊層數"drop_rate": 0.1,     # Dropout 比例"qkv_bias": False     # QKV 計算是否使用偏置
}

GPT 模型基本結構

cfg是配置實例

import torch.nn as nnclass GPTModel(nn.Module):def __init__(self, cfg):super().__init__()# Token 嵌入層self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])# 位置嵌入層self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])# Dropout 層self.drop_emb = nn.Dropout(cfg["drop_rate"])# 堆疊n_layers相同的Transformer 塊self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])# 最終層歸一化self.final_norm = LayerNorm(cfg["emb_dim"])# 輸出層self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)def forward(self, in_idx):batch_size, seq_len = in_idx.shape# Token 嵌入tok_embeds = self.tok_emb(in_idx)# 位置嵌入pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))# 組合嵌入x = tok_embeds + pos_embedsx = self.drop_emb(x)# 通過 Transformer 塊x = self.trf_blocks(x)# 最終歸一化x = self.final_norm(x)# 輸出 logitslogits = self.out_head(x)return logits

?層歸一化 (Layer Normalization)

層歸一化將激活值規范化為均值為 0、方差為 1 的分布,加速模型收斂:

class LayerNorm(nn.Module):def __init__(self, emb_dim):super().__init__()self.eps = 1e-5    # 防止除零錯誤的標準設定值self.scale = nn.Parameter(torch.ones(emb_dim))  #可學習縮放參數,初始化為全1向量self.shift = nn.Parameter(torch.zeros(emb_dim)) #可學習平移參數,初始化為全0向量def forward(self, x):mean = x.mean(dim=-1, keepdim=True)    #計算均值 μ,沿最后一維,保持維度var = x.var(dim=-1, keepdim=True, unbiased=False)    #計算方差 σ2,同均值維度,有偏估計(分母n)norm_x = (x - mean) / torch.sqrt(var + self.eps)    #標準化計算,分母添加ε防溢出return self.scale * norm_x + self.shift    #仿射變換,恢復模型表達能力

GELU 激活函數與前饋網絡

GPT 使用 GELU(高斯誤差線性單元)激活函數:

場景ReLU 的行為GELU 的行為
處理微弱負信號直接丟棄(可能丟失細節)部分保留(如:保留 30% 的信號強度)
遇到強烈正信號完全放行幾乎完全放行(保留 95% 以上)
訓練穩定性容易在臨界點卡頓平滑過渡,減少訓練震蕩
應對復雜模式需要堆疊更多層數單層就能捕捉更細膩的變化
class GELU(nn.Module):def __init__(self):super().__init__()def forward(self, x):return 0.5 * x * (1 + torch.tanh(torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))))

前饋神經網絡實現:

class FeedForward(nn.Module):def __init__(self, cfg):super().__init__()self.layers = nn.Sequential(nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),GELU(),nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),nn.Dropout(cfg["drop_rate"]))def forward(self, x):return self.layers(x)

Shortcut 連接

Shortcut 連接(殘差連接)解決深度網絡中的梯度消失問題:

class TransformerBlock(nn.Module):def __init__(self, cfg):super().__init__()self.att = MultiHeadAttention(d_in=cfg["emb_dim"],d_out=cfg["emb_dim"],block_size=cfg["ctx_len"],num_heads=cfg["n_heads"], dropout=cfg["drop_rate"],qkv_bias=cfg["qkv_bias"])self.ff = FeedForward(cfg)self.norm1 = LayerNorm(cfg["emb_dim"])self.norm2 = LayerNorm(cfg["emb_dim"])self.drop_resid = nn.Dropout(cfg["drop_rate"])def forward(self, x):# 注意力塊的殘差連接shortcut = xx = self.norm1(x)x = self.att(x)x = self.drop_resid(x)x = x + shortcut# 前饋網絡的殘差連接shortcut = xx = self.norm2(x)x = self.ff(x)x = self.drop_resid(x)x = x + shortcutreturn x

Transformer 塊整合

將多頭注意力與前饋網絡整合為 Transformer 塊:

class TransformerBlock(nn.Module):def __init__(self, cfg):super().__init__()self.att = MultiHeadAttention(d_in=cfg["emb_dim"],d_out=cfg["emb_dim"],block_size=cfg["ctx_len"],num_heads=cfg["n_heads"], dropout=cfg["drop_rate"],qkv_bias=cfg["qkv_bias"])self.ff = FeedForward(cfg)self.norm1 = LayerNorm(cfg["emb_dim"])self.norm2 = LayerNorm(cfg["emb_dim"])self.drop_resid = nn.Dropout(cfg["drop_rate"])def forward(self, x):# 注意力塊的殘差連接shortcut = xx = self.norm1(x)x = self.att(x)x = self.drop_resid(x)x = x + shortcut# 前饋網絡的殘差連接shortcut = xx = self.norm2(x)x = self.ff(x)x = self.drop_resid(x)x = x + shortcutreturn x

完整 GPT 模型實現

class GPTModel(nn.Module):def __init__(self, cfg):super().__init__()self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])self.drop_emb = nn.Dropout(cfg["drop_rate"])self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])self.final_norm = LayerNorm(cfg["emb_dim"])self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)def forward(self, in_idx):batch_size, seq_len = in_idx.shapetok_embeds = self.tok_emb(in_idx)pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))x = tok_embeds + pos_embedsx = self.drop_emb(x)x = self.trf_blocks(x)x = self.final_norm(x)logits = self.out_head(x)return logits

文本生成

使用貪婪解碼生成文本:

def generate_text_simple(model, idx, max_new_tokens, context_size):for _ in range(max_new_tokens):# 截斷超過上下文長度的部分idx_cond = idx[:, -context_size:]with torch.no_grad():logits = model(idx_cond)# 獲取最后一個 token 的 logitslogits = logits[:, -1, :]  probas = torch.softmax(logits, dim=-1)idx_next = torch.argmax(probas, dim=-1, keepdim=True)idx = torch.cat((idx, idx_next), dim=1)return idx

使用示例:

# 初始化模型
model = GPTModel(GPT_CONFIG_124M)# 設置評估模式
model.eval()# 生成文本
start_context = "Every effort moves you"
encoded = tokenizer.encode(start_context)
encoded_tensor = torch.tensor(encoded).unsqueeze(0)generated = generate_text_simple(model=model,idx=encoded_tensor,max_new_tokens=10,context_size=GPT_CONFIG_124M["ctx_len"]
)decoded_text = tokenizer.decode(generated.squeeze(0).tolist())
print(decoded_text)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/83648.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/83648.shtml
英文地址,請注明出處:http://en.pswp.cn/web/83648.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Android 中 linux 命令查詢設備信息

一、getprop 命令 在 Linux 系統中, getprop 命令通常用于獲取 Android 設備的系統屬性,這些屬性包括設備型號、Android 版本、電池狀態等。 1、獲取 Android 版本號 adb shell getprop ro.build.version.release2、獲取設備型號 adb shell getprop …

26考研 | 王道 | 計算機組成原理 | 六、總線

26考研 | 王道 | 計算機組成原理 | 六、總線 文章目錄 26考研 | 王道 | 計算機組成原理 | 六、總線6.1 總線概述1. 總線概述2. 總線的性能指標 6.2 總線仲裁(考綱沒有,看了留個印象)6.3 總線操作和定時6.4 總線標準(考綱沒有&…

SE(Secure Element)加密芯片與MCU協同工作的典型流程

以下是SE(Secure Element)加密芯片與MCU協同工作的典型流程,綜合安全認證、數據保護及防篡改機制: 一、基礎認證流程(參數保護方案) 密鑰預置? SE芯片與MCU分別預置相同的3DES密鑰(Key1、Key2…

數據庫——MongoDB

一、介紹 1. MongoDB 概述 MongoDB 是一款由 C 語言編寫的開源 NoSQL 數據庫,采用分布式文件存儲設計。作為介于關系型和非關系型數據庫之間的產品,它是 NoSQL 數據庫中最接近傳統關系數據庫的解決方案,同時保留了 NoSQL 的靈活性和擴展性。…

WebSocket 前端斷連原因與檢測方法

文章目錄 前言WebSocket 前端斷連原因與檢測方法常見 WebSocket 斷連原因及檢測方式聊天系統場景下的斷連問題與影響行情推送場景下的斷連問題與影響React 前端應對斷連的穩健策略自動重連機制的設計與節流控制心跳機制的實現與保持連接存活連接狀態管理與 React 集成錯誤提示與…

2025年真實面試問題匯總(三)

線上數據庫數據丟失如何恢復 線上數據庫數據丟失的恢復方法需要根據數據丟失原因、備份情況及數據庫類型(如MySQL、SQL Server、PostgreSQL等)綜合處理,以下是通用的分步指南: 一、緊急止損:暫停寫入,防止…

Android音視頻多媒體開源框架基礎大全

安卓多媒體開發框架中,從音頻采集,視頻采集,到音視頻處理,音視頻播放顯示分別有哪些常用的框架?分成六章,這里一次幫你總結完。 音視頻的主要流程是錄制、處理、編解碼和播放顯示。本文也遵循這個流程展開…

安卓上架華為應用市場、應用寶、iosAppStore上架流程,保姆級記錄(1)

上架前請準備好apk、備案、軟著、企業開發者賬號!!!其余準備好app相關的截圖、介紹、測試賬號,沒講解明白的評論區留言~ 華為應用市場 1、登錄賬號 打開 華為開發者平臺 https://developer.huawei.com/consumer/cn/ 2.登錄企…

【Docker】docker 常用命令

目錄 一、鏡像管理 二、容器操作 三、網絡管理 四、存儲卷管理 五、系統管理 六、Docker Compose 常用命令 一、鏡像管理 命令參數解說示例說明docker pull鏡像名:標簽docker pull nginx:alpine拉取鏡像(默認從 Docker Hub)docker images-a&#x…

OSPF域內路由

簡介 Router-LSA Router-LSA(Router Link State Advertisement)是OSPF(Open Shortest Path First)協議中的一種鏈路狀態通告(LSA),它由OSPF路由器生成,用于描述路由器自身的鏈路狀態…

torch 高維矩陣乘法分析,一文說透

文章目錄 簡介向量乘法二維矩陣乘法三維矩陣乘法廣播 高維矩陣乘法開源 簡介 一提到矩陣乘法,大家對于二維矩陣乘法都很了解,即 A 矩陣的行乘以 B 矩陣的列。 但對于高維矩陣乘法可能就不太清楚,不知道高維矩陣乘法是怎么在計算。 建議使用…

瑞薩RA-T系列芯片馬達類工程TCM加速化設置

本篇介紹在使用RA8-T系列芯片,建立馬達類工程應用時,如何將電流環部分的指令和變量設置到TCM單元,以提高電流環執行速度,從而提高系統整體的運行性能,在伺服和高端工業領域有很高的實用價值。本文以RA8T1為范例&#x…

獲取Unity節點路徑

解決目的: 避免手動拼寫節點路徑的時候,出現路徑錯誤導致獲取不到節點的情況。解決效果: 添加如下腳本之后,將自動復制路徑到剪貼板中,在代碼中通過 ctrlv 粘貼路徑代碼如下: public class CustomMenuItems…

Docker 安裝 Oracle 12C

鏡像 https://docker.aityp.com/image/docker.io/truevoly/oracle-12c:latest docker pull swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/truevoly/oracle-12c:latest docker tag swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/truevoly/oracle-12c:latest d…

Linux內核網絡協議注冊與初始化:從proto_register到tcp_v4_init_sock的深度解析

一、協議注冊:proto_register的核心使命 在Linux網絡協議棧中,proto_register是協議初始化的基石,主要完成三項關鍵任務: Slab緩存創建(內存管理核心) prot->slab = kmem_cache_create_usercopy(prot->name, prot->obj_size, ...); if (prot->twsk_prot) pr…

GD32 MCU的真隨機數發生器(TRNG)

GD32 MCU的真隨機數發生器(TRNG) 文章目錄 GD32 MCU的真隨機數發生器(TRNG)一、定義與核心特征二、物理機制:量子與經典隨機性三、生成方法四、應用場景五、與偽隨機數的對比六、局限性?? 七、物理熵源原理?? 八、硬件實現流程(以GD32F450 GD32L233為例)8.1. **初始…

Vulkan學習筆記6—渲染呈現

一、渲染循環核心 while (!glfwWindowShouldClose(window)) {glfwPollEvents();helloTriangleApp.drawFrame(); // 繪制幀} 在 Vulkan 中渲染幀包含一組常見的步驟 等待前一幀完成(vkWaitForFences) 從交換鏈獲取圖像(vkAcquireNextImageKH…

React第六十二節 Router中 createStaticRouter 的使用詳解

前言 createStaticRouter 是 React Router 專為 服務端渲染(SSR) 設計的 API,用于在服務器端處理路由匹配和數據加載。它在構建靜態 HTML 響應時替代了客戶端的 BrowserRouter,確保 SSR 和客戶端 Hydration 的路由狀態一致。 一…

qt 雙緩沖案例對比

雙緩沖 1.雙緩沖原理 單緩沖:在paintEvent中直接繪制到屏幕,繪制過程被用戶看到 雙緩沖:先在redrawBuffer繪制到緩沖區,然后一次性顯示完整結果 代碼結構 單緩沖:所有繪制邏輯在paintEvent中 雙緩沖:繪制…

華為云AI開發平臺ModelArts

華為云ModelArts:重塑AI開發流程的“智能引擎”與“創新加速器”! 在人工智能浪潮席卷全球的2025年,企業擁抱AI的意愿空前高漲,但技術門檻高、流程復雜、資源投入巨大的現實,卻讓許多創新構想止步于實驗室。數據科學家…