GPT - GPT(Generative Pre-trained Transformer)模型框架

本節代碼主要為實現了一個簡化版的 GPT(Generative Pre-trained Transformer)模型。GPT 是一種基于 Transformer 架構的語言生成模型,主要用于生成自然語言文本。
?

1.?模型結構

初始化部分
class GPT(nn.Module):def __init__(self, vocab_size, d_model, seq_len, N_blocks, dff, dropout):super().__init__()self.emb = nn.Embedding(vocab_size, d_model)self.pos = nn.Embedding(seq_len, d_model)self.layers = nn.ModuleList([TransformerDecoderBlock(d_model, dff, dropout)for i in range(N_blocks)])self.fc = nn.Linear(d_model, vocab_size)
  • vocab_size:詞匯表的大小,即模型可以處理的唯一詞元(token)的數量。

  • d_model:模型的維度,表示嵌入和內部表示的維度。

  • seq_len:序列的最大長度,即輸入序列的最大長度。

  • N_blocks:Transformer 解碼器塊的數量。

  • dff:前饋網絡(Feed-Forward Network, FFN)的維度。

  • dropout:Dropout 的概率,用于防止過擬合。

組件說明
  1. self.emb:詞嵌入層,將輸入的詞元索引映射到 d_model 維的向量空間。

  2. self.pos:位置嵌入層,將序列中每個位置的索引映射到 d_model 維的向量空間。位置嵌入用于給模型提供序列中每個詞元的位置信息。

  3. self.layers:一個模塊列表,包含 N_blocksTransformerDecoderBlock。每個塊是一個 Transformer 解碼器層,包含多頭注意力機制和前饋網絡。

  4. self.fc:一個線性層,將解碼器的輸出映射到詞匯表大小的維度,用于生成最終的詞元概率分布。

2.?前向傳播

def forward(self, x, attn_mask=None):emb = self.emb(x)pos = self.pos(torch.arange(x.shape[1]))x = emb + posfor layer in self.layers:x = layer(x, attn_mask)return self.fc(x)
步驟解析
  1. 詞嵌入和位置嵌入

    • self.emb(x):將輸入的詞元索引 x 轉換為詞嵌入表示 emb,形狀為 (batch_size, seq_len, d_model)

    • self.pos(torch.arange(x.shape[1])):生成位置嵌入 pos,形狀為 (seq_len, d_model)torch.arange(x.shape[1]) 生成一個從 0 到 seq_len-1 的序列,表示每個位置的索引。

    • x = emb + pos:將詞嵌入和位置嵌入相加,得到最終的輸入表示 x。位置嵌入的加入使得模型能夠區分序列中不同位置的詞元。

  2. Transformer 解碼器層

    • for layer in self.layers:將輸入 x 逐層傳遞給每個 TransformerDecoderBlock

    • x = layer(x, attn_mask):每個解碼器塊會處理輸入 x,并應用因果掩碼 attn_mask(如果提供)。因果掩碼確保模型在解碼時只能看到當前及之前的位置,而不能看到未來的信息。

  3. 輸出層

    • self.fc(x):將解碼器的輸出 x 傳遞給線性層 self.fc,生成最終的輸出。輸出的形狀為 (batch_size, seq_len, vocab_size),表示每個位置上每個詞元的預測概率。

截止到本篇文章GPT簡單復現完成,下面將附完整代碼,方便理解代碼整體結構

import math
import torch
import random
import torch.nn as nnfrom tqdm import tqdm
from torch.utils.data import Dataset, DataLoader'''
仿 nn.TransformerDecoderLayer 實現
'''class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, dropout):super().__init__()self.num_heads = num_headsself.d_k = d_model // num_headsself.q_project = nn.Linear(d_model, d_model)self.k_project = nn.Linear(d_model, d_model)self.v_project = nn.Linear(d_model, d_model)self.o_project = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, attn_mask=None):batch_size, seq_len, d_model = x.shapeQ = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)atten_scores = Q @ K.transpose(2, 3) / math.sqrt(self.d_k)if attn_mask is not None:attn_mask = attn_mask.unsqueeze(1)atten_scores = atten_scores.masked_fill(attn_mask == 0, -1e9)atten_scores = torch.softmax(atten_scores, dim=-1)out = atten_scores @ Vout = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)out = self.o_project(out)return self.dropout(out)class TransformerDecoderBlock(nn.Module):def __init__(self, d_model, dff, dropout):super().__init__()self.linear1 = nn.Linear(d_model, dff)self.activation = nn.GELU()# self.activation = nn.ReLU()self.dropout = nn .Dropout(dropout)self.linear2 = nn.Linear(dff, d_model)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)self.dropout3 = nn.Dropout(dropout)self.mha_block1 = MultiHeadAttention(d_model, num_heads, dropout)self.mha_block2 = MultiHeadAttention(d_model, num_heads, dropout)def forward(self, x, mask=None):x = self.norm1(x + self.dropout1(self.mha_block1(x, mask)))x = self.norm2(x + self.dropout2(self.mha_block2(x, mask)))x = self.norm3(self.linear2(self.dropout(self.activation(self.linear1(x)))))return xclass GPT(nn.Module):def __init__(self, vocab_size, d_model, seq_len, N_blocks, dff, dropout):super().__init__()self.emb = nn.Embedding(vocab_size, d_model)self.pos = nn.Embedding(seq_len, d_model)self.layers = nn.ModuleList([TransformerDecoderBlock(d_model, dff, dropout)for i in range(N_blocks)])self.fc = nn.Linear(d_model, vocab_size)def forward(self, x, attn_mask=None):emb = self.emb(x)pos = self.pos(torch.arange(x.shape[1]))x = emb + posfor layer in self.layers:x = layer(x, attn_mask)return self.fc(x)def read_data(file, num=1000):with open(file, "r", encoding="utf-8") as f:data = f.read().strip().split("\n")res = [line[:24] for line in data[:num]]return resdef tokenize(corpus):vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3, ",": 4, "。": 5, "?": 6}for line in corpus:for token in line:vocab.setdefault(token, len(vocab))idx2word = list(vocab)return vocab, idx2wordclass Tokenizer:def __init__(self, vocab, idx2word):self.vocab = vocabself.idx2word = idx2worddef encode(self, text):ids = [self.token2id(token) for token in text]return idsdef decode(self, ids):tokens = [self.id2token(id) for id in ids]return tokensdef id2token(self, id):token = self.idx2word[id]return tokendef token2id(self, token):id = self.vocab.get(token, self.vocab["[UNK]"])return idclass Poetry(Dataset):def __init__(self, poetries, tokenizer: Tokenizer):self.poetries = poetriesself.tokenizer = tokenizerself.pad_id = self.tokenizer.vocab["[PAD]"]self.bos_id = self.tokenizer.vocab["[BOS]"]self.eos_id = self.tokenizer.vocab["[EOS]"]def __len__(self):return len(self.poetries)def __getitem__(self, idx):poetry = self.poetries[idx]poetry_ids = self.tokenizer.encode(poetry)input_ids = torch.tensor([self.bos_id] + poetry_ids)input_msk = causal_mask(input_ids)label_ids = torch.tensor(poetry_ids + [self.eos_id])return {"input_ids": input_ids,"input_msk": input_msk,"label_ids": label_ids}def causal_mask(x):mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0return maskdef generate_poetry(method="greedy", top_k=5):model.eval()with torch.no_grad():input_ids = torch.tensor(vocab["[BOS]"]).view(1, -1)while input_ids.shape[1] < seq_len:output = model(input_ids, None)probabilities = torch.softmax(output[:, -1, :], dim=-1)if method == "greedy":next_token_id = torch.argmax(probabilities, dim=-1)elif method == "top_k":top_k_probs, top_k_indices = torch.topk(probabilities[0], top_k)next_token_id = top_k_indices[torch.multinomial(top_k_probs, 1)]if next_token_id == vocab["[EOS]"]:breakinput_ids = torch.cat([input_ids, next_token_id.view(1, 1)], dim=1)return input_ids.squeeze()if __name__ == "__main__":file = "/Users/azen/Desktop/llm/LLM-FullTime/dataset/text-generation/poetry_data.txt"poetries = read_data(file, num=2000)vocab, idx2word = tokenize(poetries)tokenizer = Tokenizer(vocab, idx2word)trainset = Poetry(poetries, tokenizer)batch_size = 16trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)d_model = 512seq_len = 25 # 有特殊標記符num_heads = 8dropout = 0.1dff = 4*d_modelN_blocks = 2model = GPT(len(vocab), d_model, seq_len, N_blocks, dff, dropout)lr = 1e-4optim = torch.optim.Adam(model.parameters(), lr=lr)loss_fn = nn.CrossEntropyLoss()epochs = 100for epoch in range(epochs):for batch in tqdm(trainloader, desc="Training"):batch_input_ids = batch["input_ids"]batch_input_msk = batch["input_msk"]batch_label_ids = batch["label_ids"]output = model(batch_input_ids, batch_input_msk)loss = loss_fn(output.view(-1, len(vocab)), batch_label_ids.view(-1))loss.backward()optim.step()optim.zero_grad()print("Epoch: {}, Loss: {}".format(epoch, loss))res = generate_poetry(method="top_k")text = tokenizer.decode(res)print("".join(text))pass

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/75511.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/75511.shtml
英文地址,請注明出處:http://en.pswp.cn/web/75511.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

基于FPGA的六層電梯智能控制系統 矩陣鍵盤-數碼管 上板仿真均驗證通過

基于FPGA的六層電梯智能控制系統 前言一、整體方案二、軟件設計總結 前言 本設計基于FPGA實現了一個完整的六層電梯智能控制系統&#xff0c;旨在解決傳統電梯控制系統在別墅環境中存在的個性化控制不足、響應速度慢等問題。系統采用Verilog HDL語言編程&#xff0c;基于Cyclo…

車載通信系統中基于ISO26262的功能安全與抗輻照協同設計研究

摘要&#xff1a;隨著智能網聯汽車的快速發展&#xff0c;車載通信系統正面臨著功能安全與抗輻照設計的雙重挑戰。在高可靠性要求的車載應用場景下&#xff0c;如何實現功能安全標準與抗輻照技術的協同優化&#xff0c;構建滿足ISO26262安全完整性等級要求的可靠通信架構&#…

Node.js種cluster模塊詳解

Node.js 中 cluster 模塊全部 API 詳解 1. 模塊屬性 const cluster require(cluster);// 1. isMaster // 判斷當前進程是否為主進程 console.log(是否為主進程:, cluster.isMaster);// 2. isWorker // 判斷當前進程是否為工作進程 console.log(是否為工作進程:, cluster.isW…

融合動態權重與抗刷機制的網文評分系統——基于優書網、IMDB與Reddit的混合算法實踐

? Yumuing 博客 &#x1f680; 探索技術的每一個角落&#xff0c;解碼世界的每一種可能&#xff01; &#x1f48c; 如果你對 AI 充滿好奇&#xff0c;歡迎關注博主&#xff0c;訂閱專欄&#xff0c;讓我們一起開啟這段奇妙的旅程&#xff01; 以權威用戶為核心&#xff0c;時…

使用Golang打包jar應用

文章目錄 背景Go 的 go:embed 功能介紹與打包 JAR 文件示例1. go:embed 基礎介紹基本特性基本語法 2. 嵌入 JAR 文件示例項目結構代碼實現 3. 高級用法&#xff1a;嵌入多個文件或目錄4. 使用注意事項5. 實際應用場景6. 完整示例&#xff1a;運行嵌入的JAR 背景 想把自己的一個…

前端大屏可視化項目 局部全屏(指定盒子全屏)

需求是這樣的&#xff0c;我用的項目是vue admin 項目 現在需要在做大屏項目 不希望顯示除了大屏的其他東西 于是想了這個辦法 至于大屏適配問題 請看我文章 底部的代碼直接復制就可以運行 vue2 px轉rem 大屏適配方案 postcss-pxtorem-CSDN博客 <template><div …

《2025藍橋杯C++B組:D:產值調整》

**作者的個人gitee**?? 作者的算法講解主頁?? 每日一言&#xff1a;“淚眼問花花不語&#xff0c;亂紅飛過秋千去&#x1f338;&#x1f338;” 題目 二.解題策略 本題比較簡單&#xff0c;我的思路是寫三個函數分別計算黃金白銀銅一次新產值&#xff0c;通過k次循環即可獲…

[VTK] 四元素實現旋轉平移

VTK 實現旋轉&#xff0c;有四元數的方案&#xff0c;也有 vtkTransform 的方案&#xff1b;主要示例代碼如下&#xff1a; //構造旋轉四元數vtkQuaterniond rotation;rotation.SetRotationAngleAndAxis(vtkMath::RadiansFromDegrees(90.0),0.0, 1.0, 0.0);//構造旋轉點四元數v…

華為hcie證書的有效期怎么判斷?

在ICT行業&#xff0c;華為HCIE證書堪稱含金量極高的“敲門磚”&#xff0c;擁有它往往意味著在職場上更上一層樓。然而&#xff0c;很多人在辛苦考取HCIE證書后&#xff0c;卻對其有效期相關事宜一知半解。今天&#xff0c;咱們就來好好嘮嘮華為HCIE證書的有效期怎么判斷這個關…

【精品PPT】2025固態電池知識體系及最佳實踐PPT合集(36份).zip

精品推薦&#xff0c;2025固態電池知識體系及最佳實踐PPT合集&#xff0c;共36份。供大家學習參考。 1、中科院化學所郭玉國研究員&#xff1a;固態金屬鋰電池及其關鍵材料.pdf 2、中科院物理所-李泓固態電池.pdf 3、全固態電池技術研究進展.pdf 4、全固態電池生產工藝.pdf 5、…

MySQL 中為產品添加靈活的自定義屬性(如 color/size)

方案 1&#xff1a;EAV 模型&#xff08;最靈活但較復雜&#xff09; 適合需要無限擴展自定義屬性的場景 -- 產品表 CREATE TABLE products (id INT PRIMARY KEY AUTO_INCREMENT,name VARCHAR(100),price DECIMAL(10,2) );-- 屬性名表 CREATE TABLE attributes (id INT PRIMA…

CSPM認證對項目論證的范式革新:從合規審查到價值創造的戰略躍遷

引言 在數字化轉型浪潮中&#xff0c;全球企業每年因項目論證缺陷導致的損失高達1.7萬億美元&#xff08;Gartner 2023&#xff09;。CSPM&#xff08;Certified Strategic Project Manager&#xff09;認證體系通過結構化方法論&#xff0c;將傳統的項目可行性評估升級為戰略…

CLIP中的Zero-Shot Learning原理

CLIP&#xff08;Contrastive Language-Image Pretraining&#xff09;是一種由OpenAI提出的多模態模型&#xff0c;它通過對比學習的方式同時學習圖像和文本的表示&#xff0c;并且能在多種任務中進行零樣本學習&#xff08;Zero-Shot Learning&#xff09;。CLIP模型的核心創…

spring mvc 中 RestTemplate 全面詳解及示例

RestTemplate 全面詳解及示例 1. RestTemplate 簡介 定義&#xff1a;Spring 提供的同步 HTTP 客戶端&#xff0c;支持多種 HTTP 方法&#xff08;GET/POST/PUT/DELETE 等&#xff09;&#xff0c;用于調用 RESTful API。核心特性&#xff1a; 支持請求頭、請求體、URI 參數的…

北大:LLM在NL2SQL中任務分解

&#x1f4d6;標題&#xff1a;LearNAT: Learning NL2SQL with AST-guided Task Decomposition for Large Language Models &#x1f310;來源&#xff1a;arXiv, 2504.02327 &#x1f31f;摘要 &#x1f538;自然語言到SQL&#xff08;NL2SQL&#xff09;已成為實現與數據庫…

STM32LL庫編程系列第八講——ADC模數轉換

系列文章目錄 往期文章 STM32LL庫編程系列第一講——Delay精準延時函數&#xff08;詳細&#xff0c;適合新手&#xff09; STM32LL庫編程系列第二講——藍牙USART串口通信&#xff08;步驟詳細、原理清晰&#xff09; STM32LL庫編程系列第三講——USARTDMA通信 STM32LL庫編程…

網絡5 TCP/IP 虛擬機橋接模式、NAT、僅主機模式

TCP/IP模型 用于局域網和廣域網&#xff1b;多個協議&#xff1b;每一層呼叫下一層&#xff1b;四層&#xff1b;通用標準 TCP/IP模型 OSI七層模型 應用層 應用層 表示層 會話層 傳輸層 傳輸層 網絡層 網絡層 鏈路層 數據鏈路層 物理層 鏈路層&#xff1a;傳數據幀&#xff0…

【C語言】預處理(下)(C語言完結篇)

一、#和## 1、#運算符 這里的#是一個運算符&#xff0c;整個運算符會將宏的參數轉換為字符串字面量&#xff0c;它僅可以出現在帶參數的宏的替換列表中&#xff0c;我們可以將其理解為字符串化。 我們先看下面的一段代碼&#xff1a; 第二個printf中是由兩個字符串組成的&am…

【高性能緩存Redis_中間件】一、快速上手redis緩存中間件

一、鋪墊 在當今的軟件開發領域&#xff0c;消息隊列扮演著至關重要的角色。它能夠幫助我們實現系統的異步處理、流量削峰以及系統解耦等功能&#xff0c;從而提升系統的性能和可維護性。Redis 作為一款高性能的鍵值對數據庫&#xff0c;不僅提供了豐富的數據結構&#xff0c;…

Java如何獲取文件的編碼格式?

Java獲取文件的編碼格式 在計算機中&#xff0c;文件編碼是指將文件內容轉換成二進制形式以便存儲和傳輸的過程。常見的文件編碼格式包括UTF-8、GBK等。不同的編碼使用不同的字符集和字節序列&#xff0c;因此在讀取文件時需要正確地確定文件的編碼格式 Java提供了多種方式以獲…