BERT - Bert模型框架復現

本節將實現一個基于Transformer架構的BERT模型。

1.?MultiHeadAttention 類

這個類實現了多頭自注意力機制(Multi-Head Self-Attention),是Transformer架構的核心部分。

在前幾篇文章中均有講解,直接上代碼

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, dropout):super().__init__()self.num_heads = num_headsself.d_k = d_model // num_headsself.q_proj = nn.Linear(d_model, d_model)self.k_proj = nn.Linear(d_model, d_model)self.v_proj = nn.Linear(d_model, d_model)self.o_proj = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):batch_size, seq_len, d_model = x.shapeQ = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)atten_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:mask = mask.unsqueeze(1).unsqueeze(1)atten_scores = atten_scores.masked_fill(mask == 0, -1e9)atten_scores = torch.softmax(atten_scores, dim=-1)out = atten_scores @ Vout = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)return self.dropout(self.o_proj(out))

2.?FeedForward 類

這個類實現了Transformer中的前饋網絡(Feed-Forward Network, FFN)。

在前幾篇文章中均有講解,直接上代碼

class FeedForward(nn.Module):def __init__(self, d_model, dff, dropout):super().__init__()self.W1 = nn.Linear(d_model, dff)self.act = nn.GELU()self.W2 = nn.Linear(dff, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x):return self.W2(self.dropout(self.act(self.W1(x))))

3.?TransformerEncoderBlock 類

這個類實現了Transformer架構中的一個編碼器塊(Encoder Block)。

在前幾篇文章中有Decoder的講解(與Encoder原理基本相似),直接上代碼

class TransformerEncoderBlock(nn.Module):def __init__(self, d_model, num_heads, dropout, dff):super().__init__()self.mha_block = MultiHeadAttention(d_model, num_heads, dropout)self.ffn_block = FeedForward(d_model, dff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)def forward(self, x, mask=None):res1 = self.norm1(x + self.dropout1(self.mha_block(x, mask)))res2 = self.norm2(res1 + self.dropout2(self.ffn_block(res1)))return res2

4.?BertModel 類

這個類實現了BERT模型的整體架構。

class BertModel(nn.Module):def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):super().__init__()self.tok_emb = nn.Embedding(vocab_size, d_model)self.seg_emb = nn.Embedding(3, d_model)self.pos_emb = nn.Embedding(seq_len, d_model)self.layers = nn.ModuleList([TransformerEncoderBlock(d_model, num_heads, dropout, dff)for _ in range(N_blocks)])self.norm = nn.LayerNorm(d_model)self.drop = nn.Dropout(dropout)def forward(self, x, seg_ids, mask):pos = torch.arange(x.shape[1])tok_emb = self.tok_emb(x)seg_emb = self.seg_emb(seg_ids)pos_emb = self.pos_emb(pos)x = tok_emb + seg_emb + pos_embfor layer in self.layers:x = layer(x, mask)x = self.norm(x)return x
  • 詞嵌入、段嵌入和位置嵌入

    • tok_emb:將輸入的詞索引映射到詞嵌入空間。

    • seg_emb:用于區分不同的句子(例如在BERT中,用于區分句子A和句子B)。

    • pos_emb:將位置信息編碼到嵌入空間,使模型能夠捕捉到序列中的位置信息。

  • Transformer編碼器層:通過nn.ModuleList堆疊了N_blocksTransformerEncoderBlock,每個塊都負責對輸入序列進行進一步的特征提取。

  • 層歸一化和Dropout:在所有編碼器層處理完畢后,對輸出進行層歸一化和Dropout處理,進一步穩定模型的輸出。

Bert完整代碼(標紅部分為本節所提到部分)

import re
import math
import torch
import random
import torch.nn as nnfrom transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader# nn.TransformerEncoderLayerclass MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, dropout):super().__init__()self.num_heads = num_headsself.d_k = d_model // num_headsself.q_proj = nn.Linear(d_model, d_model)self.k_proj = nn.Linear(d_model, d_model)self.v_proj = nn.Linear(d_model, d_model)self.o_proj = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):batch_size, seq_len, d_model = x.shapeQ = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)atten_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:mask = mask.unsqueeze(1).unsqueeze(1)atten_scores = atten_scores.masked_fill(mask == 0, -1e9)atten_scores = torch.softmax(atten_scores, dim=-1)out = atten_scores @ Vout = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)return self.dropout(self.o_proj(out))class FeedForward(nn.Module):def __init__(self, d_model, dff, dropout):super().__init__()self.W1 = nn.Linear(d_model, dff)self.act = nn.GELU()self.W2 = nn.Linear(dff, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x):return self.W2(self.dropout(self.act(self.W1(x))))class TransformerEncoderBlock(nn.Module):def __init__(self, d_model, num_heads, dropout, dff):super().__init__()self.mha_block = MultiHeadAttention(d_model, num_heads, dropout)self.ffn_block = FeedForward(d_model, dff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)def forward(self, x, mask=None):res1 = self.norm1(x + self.dropout1(self.mha_block(x, mask)))res2 = self.norm2(res1 + self.dropout2(self.ffn_block(res1)))return res2class BertModel(nn.Module):def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):super().__init__()self.tok_emb = nn.Embedding(vocab_size, d_model)self.seg_emb = nn.Embedding(3, d_model)self.pos_emb = nn.Embedding(seq_len, d_model)self.layers = nn.ModuleList([TransformerEncoderBlock(d_model, num_heads, dropout, dff)for _ in range(N_blocks)])self.norm = nn.LayerNorm(d_model)self.drop = nn.Dropout(dropout)def forward(self, x, seg_ids, mask):pos = torch.arange(x.shape[1])tok_emb = self.tok_emb(x)seg_emb = self.seg_emb(seg_ids)pos_emb = self.pos_emb(pos)x = tok_emb + seg_emb + pos_embfor layer in self.layers:x = layer(x, mask)x = self.norm(x)return xclass BERT(nn.Module):def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):super().__init__()self.bert = BertModel(vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff)self.mlm_head = nn.Linear(d_model, vocab_size)self.nsp_head = nn.Linear(d_model, 2)def forward(self, mlm_tok_ids, seg_ids, mask):bert_out = self.bert(mlm_tok_ids, seg_ids, mask)cls_token = bert_out[:, 0, :]mlm_logits = self.mlm_head(bert_out)nsp_logits = self.nsp_head(cls_token)return mlm_logits, nsp_logitsdef read_data(file):with open(file, "r", encoding="utf-8") as f:data = f.read().strip().replace("\n", "")corpus = re.split(r'[。,“”:;!、]', data)corpus = [sentence for sentence in corpus if sentence.strip()]return corpusdef create_nsp_dataset(corpus):nsp_dataset = []for i in range(len(corpus)-1):next_sentence = corpus[i+1]rand_id = random.randint(0, len(corpus) - 1)while abs(rand_id - i) <= 1:rand_id = random.randint(0, len(corpus) - 1)negt_sentence = corpus[rand_id]nsp_dataset.append((corpus[i], next_sentence, 1)) # 正樣本nsp_dataset.append((corpus[i], negt_sentence, 0)) # 負樣本return nsp_datasetclass BERTDataset(Dataset):def __init__(self, nsp_dataset, tokenizer: BertTokenizer, max_length):self.nsp_dataset = nsp_datasetself.tokenizer = tokenizerself.max_length = max_lengthself.cls_id = tokenizer.cls_token_idself.sep_id = tokenizer.sep_token_idself.pad_id = tokenizer.pad_token_idself.mask_id = tokenizer.mask_token_iddef __len__(self):return len(self.nsp_dataset)def __getitem__(self, idx):sent1, sent2, nsp_label = self.nsp_dataset[idx]sent1_ids = self.tokenizer.encode(sent1, add_special_tokens=False)sent2_ids = self.tokenizer.encode(sent2, add_special_tokens=False)tok_ids = [self.cls_id] + sent1_ids + [self.sep_id] + sent2_ids + [self.sep_id]seg_ids = [0]*(len(sent1_ids)+2) + [1]*(len(sent2_ids) + 1)mlm_tok_ids, mlm_labels = self.build_mlm_dataset(tok_ids)mlm_tok_ids = self.pad_to_seq_len(mlm_tok_ids, 0)seg_ids = self.pad_to_seq_len(seg_ids, 2)mlm_labels = self.pad_to_seq_len(mlm_labels, -100)mask = (mlm_tok_ids != 0)return {"mlm_tok_ids": mlm_tok_ids,"seg_ids": seg_ids,"mask": mask,"mlm_labels": mlm_labels,"nsp_labels": torch.tensor(nsp_label)}def pad_to_seq_len(self, seq, pad_value):seq = seq[:self.max_length]pad_num = self.max_length - len(seq)return torch.tensor(seq + pad_num * [pad_value])def build_mlm_dataset(self, tok_ids):mlm_tok_ids = tok_ids.copy()mlm_labels = [-100] * len(tok_ids)for i in range(len(tok_ids)):if tok_ids[i] not in [self.cls_id, self.sep_id, self.pad_id]:if random.random() < 0.15:mlm_labels[i] = tok_ids[i]if random.random() < 0.8:mlm_tok_ids[i] = self.mask_idelif random.random() < 0.9:mlm_tok_ids[i] = random.randint(106, self.tokenizer.vocab_size - 1)return mlm_tok_ids, mlm_labelsif __name__ == "__main__":data_file = "4.10-BERT/背影.txt"model_path = "/Users/azen/Desktop/llm/models/bert-base-chinese"tokenizer = BertTokenizer.from_pretrained(model_path)corpus = read_data(data_file)max_length = 25 # len(max(corpus, key=len))print("Max length of dataset: {}".format(max_length))nsp_dataset = create_nsp_dataset(corpus)trainset = BERTDataset(nsp_dataset, tokenizer, max_length)batch_size = 16trainloader = DataLoader(trainset, batch_size, shuffle=True)vocab_size = tokenizer.vocab_sized_model = 768N_blocks = 2num_heads = 12dropout = 0.1dff = 4*d_modelmodel = BERT(vocab_size, d_model, max_length, N_blocks, num_heads, dropout, dff)lr = 1e-3optim = torch.optim.Adam(model.parameters(), lr=lr)loss_fn = nn.CrossEntropyLoss()epochs = 20for epoch in range(epochs):for batch in trainloader:batch_mlm_tok_ids = batch["mlm_tok_ids"]batch_seg_ids = batch["seg_ids"]batch_mask = batch["mask"]batch_mlm_labels = batch["mlm_labels"]batch_nsp_labels = batch["nsp_labels"]mlm_logits, nsp_logits = model(batch_mlm_tok_ids, batch_seg_ids, batch_mask)loss_mlm = loss_fn(mlm_logits.view(-1, vocab_size), batch_mlm_labels.view(-1))loss_nsp = loss_fn(nsp_logits, batch_nsp_labels)loss = loss_mlm + loss_nsploss.backward()optim.step()optim.zero_grad()print("Epoch: {}, MLM Loss: {}, NSP Loss: {}".format(epoch, loss_mlm, loss_nsp))passpass

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/79012.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/79012.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/79012.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

解決 Spring Boot 啟動報錯:數據源配置引發的啟動失敗

啟動項目時&#xff0c;控制臺輸出了如下錯誤信息&#xff1a; Error starting ApplicationContext. To display the condition evaluation report re-run your application with debug enabled. 2025-04-14 21:13:33.005 [main] ERROR o.s.b.d.LoggingFailureAnalysisReporte…

履帶小車+六軸機械臂(2)

本次介紹原理圖部分 開發板部分&#xff0c;電源供電部分&#xff0c;六路舵機&#xff0c;PS2手柄接收器&#xff0c;HC-05藍牙模塊&#xff0c;蜂鳴器&#xff0c;串口&#xff0c;TB6612電機驅動模塊&#xff0c;LDO線性穩壓電路&#xff0c;按鍵部分 1、開發板部分 需要注…

【開發記錄】服務外包大賽記錄

參加服務外包大賽的A07賽道中&#xff0c;最近因為頻繁的DEBUG&#xff0c;心態爆炸 記錄錯誤 以防止再次出現錯誤浪費時間。。。 2025.4.13 項目在上傳圖片之后 會自動刷新 沒有等待后端返回 Network中的fetch /upload顯示canceled. 然而這是使用了VS的live Server插件才這樣&…

基于FreeRTOS和LVGL的多功能低功耗智能手表(硬件篇)

目錄 一、簡介 二、板子構成 三、核心板 3.1 MCU最小系統板電路 3.2 電源電路 3.3 LCD電路 3.4 EEPROM電路 3.5 硬件看門狗電路 四、背板 4.1 傳感器電路 4.2 充電盤 4.3 藍牙模塊電路 五、總結 一、簡介 本篇開始介紹這個項目的硬件部分&#xff0c;從最小電路設…

為 Kubernetes 提供智能的 LLM 推理路由:Gateway API Inference Extension 深度解析

現代生成式 AI 和大語言模型&#xff08;LLM&#xff09;服務給 Kubernetes 帶來了獨特的流量路由挑戰。與典型的短時、無狀態 Web 請求不同&#xff0c;LLM 推理會話通常是長時運行、資源密集且部分有狀態的。例如&#xff0c;一個基于 GPU 的模型服務器可能同時維護多個活躍的…

MacOs下解決遠程終端內容復制并到本地粘貼板

常常需要在服務器上搗鼓東西&#xff0c;同時需要將內容復制到本地的需求。 1-內容是在遠程終端用vim打開&#xff0c;如何用vim的類似指令達到快速復制到本地呢&#xff1f; 假設待復制的內容&#xff1a; #include <iostream> #include <cstring> using names…

STM32 vs ESP32:如何選擇最適合你的單片機?

引言 在嵌入式開發中&#xff0c;STM32 和 ESP32 是兩種最熱門的微控制器方案。但許多開發者面對項目選型時仍會感到困惑&#xff1a;到底是選擇功能強大的 STM32&#xff0c;還是集成無線的 ESP32&#xff1f; 本文將通過 硬件資源、開發場景、成本分析 等多維度對比&#xf…

【blender小技巧】Blender導出帶貼圖的FBX模型,并在unity中提取材質模型使用

前言 這其實是我之前做過的操作&#xff0c;我只是單獨提取出來了而已。感興趣可以去看看&#xff1a;【blender小技巧】使用Blender將VRM或者其他模型轉化為FBX模型&#xff0c;并在unity使用&#xff0c;導出帶貼圖的FBX模型&#xff0c;貼圖材質問題修復 一、導出帶貼圖的…

如何保證本地緩存和redis的一致性

1. Cache Aside Pattern&#xff08;旁路緩存模式&#xff09;?? ?核心思想?&#xff1a;應用代碼直接管理緩存與數據的同步&#xff0c;分為讀寫兩個流程&#xff1a; ?讀取數據?&#xff1a; 先查本地緩存&#xff08;如 Guava Cache&#xff09;。若本地未命中&…

k8s通過service標簽實現藍綠發布

k8s通過service標簽實現藍綠發布 通過k8s service label標簽實現藍綠發布方法1:使用kubelet完成藍綠切換1. 創建綠色版本1.1 創建綠色版本 Deployment1.2 創建綠色版本 Service 2. 創建藍色版本2.1 創建藍色版本 Deployment2.2 創建藍色版本 Service 3. 創建藍綠切換SVC (用于外…

智慧酒店企業站官網-前端靜態網站模板【前端練習項目】

最近又寫了一個靜態網站&#xff0c;智慧酒店宣傳官網。 使用的技術 html css js 。 特別適合編程學習者進行網頁制作和前端開發的實踐。 項目包含七個核心模塊&#xff1a;首頁、整體解決方案、優勢、全國案例、行業觀點、合作加盟、關于我們。 通過該項目&#xff0c;小伙伴們…

Epplus 8+ 許可證設置

Epplus 8 之后非商業許可證的設置變了如果還用普通的方法會報錯 Unhandled exception. OfficeOpenXml.LicenseContextPropertyObsoleteException: Please use the static ‘ExcelPackage.License’ property to set the required license information from EPPlus 8 and later …

CST1016.基于Spring Boot+Vue高校競賽管理系統

計算機/JAVA畢業設計 【CST1016.基于Spring BootVue高校競賽管理系統】 【項目介紹】 高校競賽管理系統&#xff0c;基于 DeepSeek Spring AI Spring Boot Vue 實現&#xff0c;功能豐富、界面精美 【業務模塊】 系統共有兩類用戶&#xff0c;分別是學生用戶和管理員用戶&a…

2025年第十六屆藍橋杯省賽C++ 研究生組真題

2025年第十六屆藍橋杯省賽C 研究生組真題 1.說明2.題目A&#xff1a;數位倍數&#xff08;5分&#xff09;3.題目B&#xff1a;IPv6&#xff08;5分&#xff09;4.題目C&#xff1a;變換數組&#xff08;10分&#xff09;5.題目D&#xff1a;最大數字&#xff08;10分&#xff…

空調開機啟動后發出噼里啪啦的異響分析與解決

背景 當空調使用時由于制冷或制熱運轉時&#xff08;關機后可能也會出現&#xff09;&#xff0c;塑料件熱脹冷縮引起&#xff0c;可能會出現“咔咔”的聲音&#xff1b;空調冷媒在空調內管路流動時會出現輕微的“沙沙”的聲音&#xff1b;也有可能是新裝的空調擺風軸出現響聲…

BERT、T5、ViT 和 GPT-3 架構概述及代表性應用

BERT、T5、ViT 和 GPT-3 架構概述 1. BERT&#xff08;Bidirectional Encoder Representations from Transformers&#xff09; 架構特點 基于 Transformer 編碼器&#xff1a;BERT 使用多層雙向 Transformer 編碼器&#xff0c;能夠同時捕捉輸入序列中每個詞的左右上下文信息…

選導師原理

總述 一句話總結&#xff1a;是雷一定要避&#xff0c;好的一定要搶。方向契合最好&#xff0c;不契合適當取舍。 首先明確自身需求&#xff1a; 我要學東西&#xff01;青年導師&#xff0c;好溝通&#xff0c;有沖勁&#xff0c;高壓力。 我要擺爛&#xff01;中老年男性教…

【過程控制系統】PID算式實現,控制系統分類,工程應用中控制系統應該注意的問題

目錄 1-1 試簡述過程控制的發展概況及各個階段的主要特點。 1-2 與其它自動控制相比&#xff0c;過程控制有哪些優點&#xff1f;為什么說過程控制的控制過程多屬慢過程&#xff1f; 1-3 什么是過程控制系統&#xff0c;其基本分類是什么&#xff1f; 1-4 何為集散控制系統…

2025年第十六屆藍橋杯省賽真題解析 Java B組(簡單經驗分享)

之前一年拿了國二后&#xff0c;基本就沒刷過題了&#xff0c;實力掉了好多&#xff0c;這次參賽只是為了學校的加分水水而已&#xff0c;希望能拿個省三吧 >_< 目錄 1. 逃離高塔思路代碼 2. 消失的藍寶思路代碼 3. 電池分組思路代碼 4. 魔法科考試思路代碼 5. 爆破思路…

JAVA EE_文件操作和IO

人們大多數時候的焦慮&#xff0c;大概是太想要一個那不確定的答案了吧。 一一 陳長生. 1.認識文件 現實中&#xff0c;我們把一張一張有內容的紙整合在一起稱為文件&#xff0c;計算機中&#xff0c;通過硬盤這種I/O設備進行數據保存時&#xff0c;它會獨立成一個一個的單位保…