BERT - MLM 和 NSP

本節代碼將實現BERT模型的兩個主要預訓練任務:掩碼語言模型(Masked Language Model, MLM)下一句預測(Next Sentence Prediction, NSP)

1.?create_nsp_dataset?函數

這個函數用于生成NSP任務的數據集。

def create_nsp_dataset(corpus):nsp_dataset = []for i in range(len(corpus)-1):next_sentence = corpus[i+1]rand_id = random.randint(0, len(corpus) - 1)while abs(rand_id - i) <= 1:rand_id = random.randint(0, len(corpus) - 1)negt_sentence = corpus[rand_id]nsp_dataset.append((corpus[i], next_sentence, 1))  # 正樣本nsp_dataset.append((corpus[i], negt_sentence, 0))  # 負樣本return nsp_dataset
  • 正樣本corpus[i]corpus[i+1] 是連續的句子對,標記為 1,表示它們是相鄰的句子。

  • 負樣本corpus[i] 和隨機選擇的句子 corpus[rand_id] 組成一個句子對,標記為 0,表示它們不是相鄰的句子。

  • 隨機選擇負樣本:通過隨機選擇句子來生成負樣本,確保模型能夠學習區分相鄰句子和非相鄰句子。

2.?BERTDataset?類

這個類繼承自 torch.utils.data.Dataset,用于加載和處理BERT預訓練任務的數據。

def __init__(self, nsp_dataset, tokenizer: BertTokenizer, max_length):self.nsp_dataset = nsp_datasetself.tokenizer = tokenizerself.max_length = max_lengthself.cls_id = tokenizer.cls_token_idself.sep_id = tokenizer.sep_token_idself.pad_id = tokenizer.pad_token_idself.mask_id = tokenizer.mask_token_id
  • nsp_dataset:存儲NSP任務的數據集,每個樣本是一個三元組 (sent1, sent2, nsp_label)

  • tokenizer:用于將文本轉換為詞索引(token IDs)。

  • max_length:序列的最大長度,用于填充或截斷。

  • 特殊標記

    • self.cls_id[CLS] 標記的索引。

    • self.sep_id[SEP] 標記的索引。

    • self.pad_id[PAD] 標記的索引。

    • self.mask_id[MASK] 標記的索引。

__len__?方法
def __len__(self):return len(self.nsp_dataset)
  • 返回數據集的大小,即樣本數量。

__getitem__?方法
def __getitem__(self, idx):sent1, sent2, nsp_label = self.nsp_dataset[idx]sent1_ids = self.tokenizer.encode(sent1, add_special_tokens=False)sent2_ids = self.tokenizer.encode(sent2, add_special_tokens=False)tok_ids = [self.cls_id] + sent1_ids + [self.sep_id] + sent2_ids + [self.sep_id]seg_ids = [0]*(len(sent1_ids)+2) + [1]*(len(sent2_ids) + 1)mlm_tok_ids, mlm_labels = self.build_mlm_dataset(tok_ids)mlm_tok_ids = self.pad_to_seq_len(mlm_tok_ids, 0)seg_ids = self.pad_to_seq_len(seg_ids, 2)mlm_labels = self.pad_to_seq_len(mlm_labels, -100)mask = (mlm_tok_ids != 0)return {"mlm_tok_ids": mlm_tok_ids,"seg_ids": seg_ids,"mask": mask,"mlm_labels": mlm_labels,"nsp_labels": torch.tensor(nsp_label)}
  • 句子編碼

    • sent1_idssent2_ids 分別是兩個句子的詞索引列表。

    • 使用 self.tokenizer.encode 將句子轉換為詞索引,add_special_tokens=False 表示不添加特殊標記([CLS][SEP])。

  • 構建輸入序列

    • tok_ids:將兩個句子的詞索引列表組合成一個序列,中間用 [SEP] 分隔,并在開頭添加 [CLS]

    • seg_ids:段嵌入索引,第一個句子使用 0,第二個句子使用 1

  • MLM任務

    • mlm_tok_idsmlm_labels 是通過 build_mlm_dataset 方法生成的,用于MLM任務。

  • 填充和截斷

    • 使用 pad_to_seq_len 方法將 mlm_tok_idsseg_idsmlm_labels 填充或截斷到 max_length

  • 掩碼

    • mask:生成一個掩碼,用于標記哪些位置是有效的輸入(非填充部分)。

pad_to_seq_len?方法
def pad_to_seq_len(self, seq, pad_value):seq = seq[:self.max_length]pad_num = self.max_length - len(seq)return torch.tensor(seq + pad_num * [pad_value])
設計原因
  • 將序列截斷到 max_length,并用 pad_value 填充到 max_length

build_mlm_dataset?方法
def build_mlm_dataset(self, tok_ids):mlm_tok_ids = tok_ids.copy()mlm_labels = [-100] * len(tok_ids)for i in range(len(tok_ids)):if tok_ids[i] not in [self.cls_id, self.sep_id, self.pad_id]:if random.random() < 0.15:mlm_labels[i] = tok_ids[i]if random.random() < 0.8:mlm_tok_ids[i] = self.mask_idelif random.random() < 0.9:mlm_tok_ids[i] = random.randint(106, self.tokenizer.vocab_size - 1)return mlm_tok_ids, mlm_labels
  • MLM任務

    • 隨機選擇一些詞(概率為15%),并將它們替換為 [MASK](80%)、隨機詞(10%)或保持不變(10%)。

    • mlm_labels 用于存儲被替換詞的真實索引,未被替換的位置標記為 -100(PyTorch中忽略計算損失的標記)。

Bert完整代碼(標紅部分為本節所提到部分)

import re
import math
import torch
import random
import torch.nn as nnfrom transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader# nn.TransformerEncoderLayerclass MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, dropout):super().__init__()self.num_heads = num_headsself.d_k = d_model // num_headsself.q_proj = nn.Linear(d_model, d_model)self.k_proj = nn.Linear(d_model, d_model)self.v_proj = nn.Linear(d_model, d_model)self.o_proj = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):batch_size, seq_len, d_model = x.shapeQ = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)atten_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:mask = mask.unsqueeze(1).unsqueeze(1)atten_scores = atten_scores.masked_fill(mask == 0, -1e9)atten_scores = torch.softmax(atten_scores, dim=-1)out = atten_scores @ Vout = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)return self.dropout(self.o_proj(out))class FeedForward(nn.Module):def __init__(self, d_model, dff, dropout):super().__init__()self.W1 = nn.Linear(d_model, dff)self.act = nn.GELU()self.W2 = nn.Linear(dff, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x):return self.W2(self.dropout(self.act(self.W1(x))))class TransformerEncoderBlock(nn.Module):def __init__(self, d_model, num_heads, dropout, dff):super().__init__()self.mha_block = MultiHeadAttention(d_model, num_heads, dropout)self.ffn_block = FeedForward(d_model, dff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)def forward(self, x, mask=None):res1 = self.norm1(x + self.dropout1(self.mha_block(x, mask)))res2 = self.norm2(res1 + self.dropout2(self.ffn_block(res1)))return res2class BertModel(nn.Module):def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):super().__init__()self.tok_emb = nn.Embedding(vocab_size, d_model)self.seg_emb = nn.Embedding(3, d_model)self.pos_emb = nn.Embedding(seq_len, d_model)self.layers = nn.ModuleList([TransformerEncoderBlock(d_model, num_heads, dropout, dff)for _ in range(N_blocks)])self.norm = nn.LayerNorm(d_model)self.drop = nn.Dropout(dropout)def forward(self, x, seg_ids, mask):pos = torch.arange(x.shape[1])tok_emb = self.tok_emb(x)seg_emb = self.seg_emb(seg_ids)pos_emb = self.pos_emb(pos)x = tok_emb + seg_emb + pos_embfor layer in self.layers:x = layer(x, mask)x = self.norm(x)return xclass BERT(nn.Module):def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):super().__init__()self.bert = BertModel(vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff)self.mlm_head = nn.Linear(d_model, vocab_size)self.nsp_head = nn.Linear(d_model, 2)def forward(self, mlm_tok_ids, seg_ids, mask):bert_out = self.bert(mlm_tok_ids, seg_ids, mask)cls_token = bert_out[:, 0, :]mlm_logits = self.mlm_head(bert_out)nsp_logits = self.nsp_head(cls_token)return mlm_logits, nsp_logitsdef read_data(file):with open(file, "r", encoding="utf-8") as f:data = f.read().strip().replace("\n", "")corpus = re.split(r'[。,“”:;!、]', data)corpus = [sentence for sentence in corpus if sentence.strip()]return corpusdef create_nsp_dataset(corpus):nsp_dataset = []for i in range(len(corpus)-1):next_sentence = corpus[i+1]rand_id = random.randint(0, len(corpus) - 1)while abs(rand_id - i) <= 1:rand_id = random.randint(0, len(corpus) - 1)negt_sentence = corpus[rand_id]nsp_dataset.append((corpus[i], next_sentence, 1)) # 正樣本nsp_dataset.append((corpus[i], negt_sentence, 0)) # 負樣本return nsp_datasetclass BERTDataset(Dataset):def __init__(self, nsp_dataset, tokenizer: BertTokenizer, max_length):self.nsp_dataset = nsp_datasetself.tokenizer = tokenizerself.max_length = max_lengthself.cls_id = tokenizer.cls_token_idself.sep_id = tokenizer.sep_token_idself.pad_id = tokenizer.pad_token_idself.mask_id = tokenizer.mask_token_iddef __len__(self):return len(self.nsp_dataset)def __getitem__(self, idx):sent1, sent2, nsp_label = self.nsp_dataset[idx]sent1_ids = self.tokenizer.encode(sent1, add_special_tokens=False)sent2_ids = self.tokenizer.encode(sent2, add_special_tokens=False)tok_ids = [self.cls_id] + sent1_ids + [self.sep_id] + sent2_ids + [self.sep_id]seg_ids = [0]*(len(sent1_ids)+2) + [1]*(len(sent2_ids) + 1)mlm_tok_ids, mlm_labels = self.build_mlm_dataset(tok_ids)mlm_tok_ids = self.pad_to_seq_len(mlm_tok_ids, 0)seg_ids = self.pad_to_seq_len(seg_ids, 2)mlm_labels = self.pad_to_seq_len(mlm_labels, -100)mask = (mlm_tok_ids != 0)return {"mlm_tok_ids": mlm_tok_ids,"seg_ids": seg_ids,"mask": mask,"mlm_labels": mlm_labels,"nsp_labels": torch.tensor(nsp_label)}def pad_to_seq_len(self, seq, pad_value):seq = seq[:self.max_length]pad_num = self.max_length - len(seq)return torch.tensor(seq + pad_num * [pad_value])def build_mlm_dataset(self, tok_ids):mlm_tok_ids = tok_ids.copy()mlm_labels = [-100] * len(tok_ids)for i in range(len(tok_ids)):if tok_ids[i] not in [self.cls_id, self.sep_id, self.pad_id]:if random.random() < 0.15:mlm_labels[i] = tok_ids[i]if random.random() < 0.8:mlm_tok_ids[i] = self.mask_idelif random.random() < 0.9:mlm_tok_ids[i] = random.randint(106, self.tokenizer.vocab_size - 1)return mlm_tok_ids, mlm_labelsif __name__ == "__main__":data_file = "4.10-BERT/背影.txt"model_path = "/Users/azen/Desktop/llm/models/bert-base-chinese"tokenizer = BertTokenizer.from_pretrained(model_path)corpus = read_data(data_file)max_length = 25 # len(max(corpus, key=len))print("Max length of dataset: {}".format(max_length))nsp_dataset = create_nsp_dataset(corpus)trainset = BERTDataset(nsp_dataset, tokenizer, max_length)batch_size = 16trainloader = DataLoader(trainset, batch_size, shuffle=True)vocab_size = tokenizer.vocab_sized_model = 768N_blocks = 2num_heads = 12dropout = 0.1dff = 4*d_modelmodel = BERT(vocab_size, d_model, max_length, N_blocks, num_heads, dropout, dff)lr = 1e-3optim = torch.optim.Adam(model.parameters(), lr=lr)loss_fn = nn.CrossEntropyLoss()epochs = 20for epoch in range(epochs):for batch in trainloader:batch_mlm_tok_ids = batch["mlm_tok_ids"]batch_seg_ids = batch["seg_ids"]batch_mask = batch["mask"]batch_mlm_labels = batch["mlm_labels"]batch_nsp_labels = batch["nsp_labels"]mlm_logits, nsp_logits = model(batch_mlm_tok_ids, batch_seg_ids, batch_mask)loss_mlm = loss_fn(mlm_logits.view(-1, vocab_size), batch_mlm_labels.view(-1))loss_nsp = loss_fn(nsp_logits, batch_nsp_labels)loss = loss_mlm + loss_nsploss.backward()optim.step()optim.zero_grad()print("Epoch: {}, MLM Loss: {}, NSP Loss: {}".format(epoch, loss_mlm, loss_nsp))passpass

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/901029.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/901029.shtml
英文地址,請注明出處:http://en.pswp.cn/news/901029.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

“實時滾動”插件:一個簡單的基于vue.js的無縫滾動

1、參考連接&#xff1a; 安裝 | vue-seamless-scroll 2、使用步驟&#xff1a; 第一步&#xff1a;安裝 yarn add vue-seamless-scroll 第二步&#xff1a;引入 import vueSeamlessScroll from vue-seamless-scroll/src 第三步&#xff1a;注冊 components: { vueSeamless…

【藍橋杯】賽前練習

1. 排序 import os import sysn=int(input()) data=list(map(int,input().split(" "))) data.sort() for d in data:print(d,end=" ") print() for d in data[::-1]:print(d,end=" ")2. 走迷宮BFS import os import sys from collections import…

pyTorch-遷移學習-學習率衰減-四種天氣圖片多分類問題

目錄 1.導包 2.加載數據、拼接訓練、測試數據的文件夾路徑 3.數據預處理 3.1 transforms.Compose數據轉化 3.2分類存儲的圖片數據創建dataloader torchvision.datasets.ImageFolder torch.utils.data.DataLoader 4.加載預訓練好的模型(遷移學習) 4.1固定、修改預訓練…

第十四屆藍橋杯大賽軟件賽國賽Python大學B組題解

文章目錄 彈珠堆放劃分偶串交易賬本背包問題翻轉最大階梯最長回文前后綴貿易航線困局 彈珠堆放 遞推式 a i a i ? 1 i a_ia_{i-1}i ai?ai?1?i&#xff0c; n 20230610 n20230610 n20230610非常小&#xff0c;直接模擬 答案等于 494 494 494 劃分 因為總和為 1 e 6 1e6…

Python 和 JavaScript兩種語言的相似部分-由DeepSeek產生

Python 和 JavaScript 作為兩種流行的編程語言&#xff0c;雖然在設計目標和應用場景上有差異&#xff08;Python 偏向后端和腳本&#xff0c;JavaScript 偏向前端和動態交互&#xff09;&#xff0c;但它們的語法存在許多相似之處。以下是兩者在語法上的主要共同點及對比&…

改善 Maven 的依賴性

大家好&#xff0c;這里是架構資源棧&#xff01;點擊上方關注&#xff0c;添加“星標”&#xff0c;一起學習大廠前沿架構&#xff01; 建議使用mvn dependency:analyze命令來擺脫已聲明但未使用的依賴項&#xff1a; 還有另一個用例&#xff0c; mvn dependency:analyze 它可…

【SQL】子查詢詳解(附例題)

子查詢 子查詢的表示形式為&#xff1a;(SELECT 語句)&#xff0c;它是IN、EXISTS等運算符的運算數&#xff0c;它也出現于FROM子句和VALUES子句。包含子查詢的查詢叫做嵌套查詢。嵌套查詢分為相關嵌套查詢和不想關嵌套查詢 WHERE子句中的子查詢 比較運算符 子查詢的結果是…

Stable Diffusion 擴展知識實操整合

本文的例子都是基于秋葉整合包打開的webui實現的 一、ADetailer——改善人臉扭曲、惡心 After detailer插件可以自動檢測生成圖片的人臉&#xff0c;針對人臉自動上蒙版&#xff0c;自動進行重繪&#xff0c;整個流程一氣呵成&#xff0c;因此可以避免許多重復的操作。除此之…

freertos內存管理簡要概述

概述 內存管理的重要性 在嵌入式系統中&#xff0c;內存資源通常是有限的。合理的內存管理可以確保系統高效、穩定地運行&#xff0c;避免因內存泄漏、碎片化等問題導致系統崩潰或性能下降。FreeRTOS 的內存管理機制有助于開發者靈活地分配和釋放內存&#xff0c;提高內存利用…

按規則批量修改文件擴展名、刪除擴展名或添加擴展名

文件的擴展名是多種多樣的&#xff0c;有些不同文件的擴展名之間相互是可以直接轉換的。我們工作當中最常見的就是 doc 與 docx、xls 與 xlsx、jpg 與 jpeg、html 與 htm 等等&#xff0c;這些格式在大部分場景下都是可以相互轉換 能直接兼容的。我們今天要介紹的就是如何按照一…

熱門面試題第15天|最大二叉樹 合并二叉樹 驗證二叉搜索樹 二叉搜索樹中的搜索

654.最大二叉樹 力扣題目地址(opens new window) 給定一個不含重復元素的整數數組。一個以此數組構建的最大二叉樹定義如下&#xff1a; 二叉樹的根是數組中的最大元素。左子樹是通過數組中最大值左邊部分構造出的最大二叉樹。右子樹是通過數組中最大值右邊部分構造出的最大…

MySQL學習筆記7【InnoDB】

Innodb 1. 架構 1.1 內存部分 buffer pool 緩沖池是主存中的第一個區域&#xff0c;里面可以緩存磁盤上經常操作的真實數據&#xff0c;在執行增刪查改操作時&#xff0c;先操作緩沖池中的數據&#xff0c;然后以一定頻率刷新到磁盤&#xff0c;這樣操作明顯提升了速度。 …

RNN、LSTM、GRU匯總

RNN、LSTM、GRU匯總 0、論文匯總1.RNN論文2、LSTM論文3、GRU4、其他匯總 1、發展史2、配置和架構1.配置2.架構 3、基本結構1.神經元2.RNN1. **RNN和前饋網絡區別&#xff1a;**2. 計算公式&#xff1a;3. **梯度消失:**4. **RNN類型**:&#xff08;查看發展史&#xff09;5. **…

django數據遷移操作受阻

錯誤信息&#xff1a; django.db.utils.OperationalError: (1227, Access denied; you need (at least one of) the SYSTEM_VARIABLES_ADMIN or SESSION_VARIABLES_ADMIN privilege(s) for this operation)根據錯誤信息分析&#xff0c;該問題是由于MySQL用戶 缺乏SYSTEM_VARI…

WinForm真入門(13)——ListBox控件詳解

WinForm ListBox 詳解與案例 一、核心概念 ?ListBox? 是 Windows 窗體中用于展示可滾動列表項的控件&#xff0c;支持單選或多選操作&#xff0c;適用于需要用戶從固定數據集中選擇一項或多項的場景?。 二、核心屬性 屬性說明?Items?管理列表項的集合&#xff0c;支持動…

局域網內文件共享的實用軟件推薦

軟件介紹 在日常辦公、學習或家庭網絡環境里&#xff0c;局域網內文件共享是個常見需求。有一款免費的局域網共享軟件非常適合這種場景。 這款局域網共享軟件使用起來非常簡單&#xff0c;不需要安裝&#xff0c;直接點擊就能使用。 它的操作流程簡單易懂&#xff0c;用戶只要…

ViewModel vs AndroidViewModel:核心區別與使用場景詳解

在 Android 的 MVVM 架構中&#xff0c;ViewModel 和 AndroidViewModel 都是用于管理 UI 相關數據的組件&#xff0c;但二者有一些關鍵區別&#xff1a; 1. ViewModel 基本用途&#xff1a;用于存儲和管理與 UI 相關的數據&#xff0c;生命周期與 Activity/Fragment 解耦&…

C語言--求n以內的素數(質數)

求n以內的素數&#xff0c;可以用試除法或者埃拉托斯特尼篩法&#xff08;埃氏篩法&#xff09; 文章目錄 試除法埃拉托斯特尼篩法&#xff08;埃氏篩法&#xff09;兩種方法測試運行效率 輸入&#xff1a;數字n 輸出&#xff1a;n以內所有的素數 不管是哪個方法&#xff0c;都…

Skynet.socket 函數族使用詳解

目錄 Skynet.socket 函數族使用詳解核心功能分類一、TCP 連接管理1. 監聽端口2. 建立連接3. 關閉連接 二、數據讀寫操作1. 阻塞式讀取2. 寫入數據2.1 socket.write(fd, data) 的返回值2.2 示例代碼2.3 關鍵注意事項2.4 與其他函數的區別2.5 底層原理2.6 總結 三、UDP 處理1. 創…

Unity Addressables資源生命周期自動化監控技術詳解

一、Addressables資源生命周期管理痛點 1. 常見資源泄漏場景 泄漏類型典型表現檢測難度隱式引用泄漏腳本持有AssetReference未釋放高異步操作未處理AsyncOperationHandle未釋放中循環依賴泄漏資源相互引用無法釋放極高事件訂閱泄漏未取消事件監聽導致對象保留高 2. 傳統管理…