探索大語言模型(LLM):Transformer 與 BERT從原理到實踐

Transformer 與 BERT:從原理到實踐

  • 前言
  • 一、背景介紹
  • 二、核心公式推導
    • 1. 注意力機制(Attention Mechanism)
    • 2. 多頭注意力機制(Multi-Head Attention)
    • 3. Transformer 編碼器(Transformer Encoder)
    • 4. BERT 的預訓練任務
  • 三、代碼實現
    • 1. 注意力機制
    • 2. 多頭注意力機制
    • 3. Transformer 編碼器層
    • 4. Transformer 編碼器
    • 5. BERT 模型
  • 四、總結


前言

在自然語言處理(NLP)的發展歷程中,Transformer 和 BERT 無疑是具有里程碑意義的技術。它們的出現,徹底改變了 NLP 領域的研究和應用格局。本文將深入探討 Transformer 和 BERT 的背景、核心公式推導,并提供代碼實現,幫助大家更好地理解和應用這兩項技術。

一、背景介紹

在 Transformer 出現之前,循環神經網絡(RNN)及其變體長短時記憶網絡(LSTM)、門控循環單元(GRU)等在 NLP 任務中占據主導地位。RNN 能夠處理序列數據,通過隱狀態傳遞信息,從而捕捉上下文依賴關系。然而,RNN 存在嚴重的梯度消失和梯度爆炸問題,使得訓練深層網絡變得困難。此外,RNN 的順序計算特性導致其難以并行化,處理長序列時效率低下。

為了解決這些問題,2017 年谷歌團隊在論文《Attention Is All You Need》中提出了 Transformer 架構。Transformer 完全摒棄了循環結構,采用多頭注意力機制(Multi-Head Attention)替代 RNN,實現了并行計算,大幅提高了訓練效率。同時,多頭注意力機制能夠更好地捕捉序列中的長距離依賴關系,在機器翻譯、文本生成等多個 NLP 任務中取得了優異的性能。

BERT(Bidirectional Encoder Representations from Transformers)則是基于 Transformer 的預訓練語言模型,由谷歌在 2018 年提出。與傳統的語言模型(如 Word2Vec、GPT)不同,BERT 采用雙向 Transformer 編碼器,能夠同時利用上下文信息,學習到更豐富的語義表示。通過在大規模文本數據上進行預訓練,并在特定任務上進行微調,BERT 在問答系統、文本分類、命名實體識別等眾多 NLP 任務中刷新了當時的最優成績,開啟了預訓練模型在 NLP 領域的新時代。

二、核心公式推導

1. 注意力機制(Attention Mechanism)

注意力機制的核心思想是根據輸入序列的不同部分對當前任務的重要程度,分配不同的權重,從而聚焦于關鍵信息。其計算過程如下:
給定查詢向量 (Q),鍵向量 (K) 和值向量 (V),注意力分數 (scores) 計算為:
s c o r e s = Q K T d k scores = \frac{QK^T}{\sqrt{d_k}} scores=dk? ?QKT?
其中, d k d_k dk? 是鍵向量 K K K 的維度, d k \sqrt{d_k} dk? ? 用于縮放,防止分數過大導致 softmax 函數梯度消失。
通過 softmax 函數對注意力分數進行歸一化,得到注意力權重 (attention weights):
a t t e n t i o n _ w e i g h t s = s o f t m a x ( s c o r e s ) attention\_weights = softmax(scores) attention_weights=softmax(scores)
最后,加權求和得到注意力輸出 A t t e n t i o n ( Q , K , V ) Attention(Q, K, V) Attention(Q,K,V)
A t t e n t i o n ( Q , K , V ) = a t t e n t i o n _ w e i g h t s ? V Attention(Q, K, V) = attention\_weights \cdot V Attention(Q,K,V)=attention_weights?V

2. 多頭注意力機制(Multi-Head Attention)

多頭注意力機制通過多個獨立的注意力頭并行計算,從不同角度捕捉輸入序列的特征,然后將各個頭的輸出拼接并線性變換得到最終輸出。具體計算過程如下:
首先,將輸入 X X X分別通過三個線性變換得到 Q Q Q K K K V V V
Q = X W Q K = X W K V = X W V Q = XW^Q\\ K = XW^K\\ V = XW^V Q=XWQK=XWKV=XWV
其中, W Q W^Q WQ W K W^K WK W V W^V WV 是可學習的權重矩陣。
然后,將 Q Q Q K K K V V V 分割成 h h h 個頭部(head),每個頭部的維度為 d k / h d_{k/h} dk/h? d v / h d_{v/h} dv/h?
Q i = Q ( i ? 1 ) d k / h : i d k / h K i = K ( i ? 1 ) d k / h : i d k / h V i = V ( i ? 1 ) d v / h : i d v / h Q_i = Q_{(i-1)d_{k/h}:id_{k/h}} \\ K_i = K_{(i-1)d_{k/h}:id_{k/h}} \\ V_i = V_{(i-1)d_{v/h}:id_{v/h}} Qi?=Q(i?1)dk/h?:idk/h??Ki?=K(i?1)dk/h?:idk/h??Vi?=V(i?1)dv/h?:idv/h??
對每個頭部分別計算注意力輸出:
h e a d i = A t t e n t i o n ( Q i , K i , V i ) head_i = Attention(Q_i, K_i, V_i) headi?=Attention(Qi?,Ki?,Vi?)
將所有頭部的輸出拼接起來:
c o n c a t ( h e a d 1 , . . . , h e a d h ) concat(head_1, ..., head_h) concat(head1?,...,headh?)
最后,通過一個線性變換得到多頭注意力機制的最終輸出:
M u l t i H e a d A t t e n t i o n ( X ) = W O ? c o n c a t ( h e a d 1 , . . . , h e a d h ) MultiHeadAttention(X) = W^O \cdot concat(head_1, ..., head_h) MultiHeadAttention(X)=WO?concat(head1?,...,headh?)
其中, W O W^O WO是可學習的權重矩陣。

3. Transformer 編碼器(Transformer Encoder)

Transformer 編碼器由多個相同的層堆疊而成,每個層包含兩個子層:多頭注意力機制子層和前饋神經網絡子層。每個子層都使用了殘差連接(Residual Connection)和層歸一化(Layer Normalization)。
輸入 (X) 首先經過多頭注意力機制子層:
X 1 = L a y e r N o r m ( X + M u l t i H e a d A t t e n t i o n ( X ) ) X_1 = LayerNorm(X + MultiHeadAttention(X)) X1?=LayerNorm(X+MultiHeadAttention(X))
然后, X 1 X_1 X1? 經過前饋神經網絡子層:
X 2 = L a y e r N o r m ( X 1 + F F N ( X 1 ) ) X_2 = LayerNorm(X_1 + FFN(X_1)) X2?=LayerNorm(X1?+FFN(X1?))
其中, F F N ( X ) FFN(X) FFN(X)是前饋神經網絡,通常由兩個線性層和一個激活函數組成:
F F N ( X ) = m a x ( 0 , X W 1 + b 1 ) W 2 + b 2 FFN(X) = max(0, XW_1 + b_1)W_2 + b_2 FFN(X)=max(0,XW1?+b1?)W2?+b2?

4. BERT 的預訓練任務

BERT 采用了兩個預訓練任務:掩碼語言模型(Masked Language Model,MLM)和下一句預測(Next Sentence Prediction,NSP)。
在掩碼語言模型任務中,隨機將輸入文本中的一些單詞替換為 [MASK] 標記,然后讓模型預測這些被掩碼的單詞。例如,對于句子 “I love natural language processing”,可能會將 “love” 替換為 [MASK],模型需要根據上下文預測出 “love”。
下一句預測任務用于學習句子之間的關系。給定一對句子,判斷第二個句子是否是第一個句子的下一句。例如,句子對 “今天天氣很好。我們去公園散步吧。” 標簽為正例,而 “今天天氣很好。我喜歡吃蘋果。” 標簽為負例。
BERT 通過最小化這兩個任務的損失函數進行預訓練,損失函數為:
L = L M L M + L N S P L = L_{MLM} + L_{NSP} L=LMLM?+LNSP?

三、代碼實現

下面我們使用 PyTorch 實現一個簡單的 Transformer 編碼器和 BERT 預訓練模型。

1. 注意力機制

import torch
import torch.nn as nndef scaled_dot_product_attention(Q, K, V):d_k = K.size(-1)scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))attention_weights = nn.functional.softmax(scores, dim=-1)return torch.matmul(attention_weights, V)

2. 多頭注意力機制

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelself.depth = d_model // num_headsself.WQ = nn.Linear(d_model, d_model)self.WK = nn.Linear(d_model, d_model)self.WV = nn.Linear(d_model, d_model)self.WO = nn.Linear(d_model, d_model)def split_heads(self, x):batch_size = x.size(0)return x.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)def forward(self, X):Q = self.split_heads(self.WQ(X))K = self.split_heads(self.WK(X))V = self.split_heads(self.WV(X))attention = scaled_dot_product_attention(Q, K, V)concatenated_attention = attention.transpose(1, 2).contiguous().view(-1, self.d_model)return self.WO(concatenated_attention)

3. Transformer 編碼器層

class TransformerEncoderLayer(nn.Module):def __init__(self, d_model, num_heads):super(TransformerEncoderLayer, self).__init__()self.attention = MultiHeadAttention(d_model, num_heads)self.ffn = nn.Sequential(nn.Linear(d_model, d_model * 4),nn.ReLU(),nn.Linear(d_model * 4, d_model))self.layernorm1 = nn.LayerNorm(d_model)self.layernorm2 = nn.LayerNorm(d_model)def forward(self, X):attn_output = self.attention(X)X = self.layernorm1(X + attn_output)ffn_output = self.ffn(X)return self.layernorm2(X + ffn_output)

4. Transformer 編碼器

class TransformerEncoder(nn.Module):def __init__(self, num_layers, d_model, num_heads):super(TransformerEncoder, self).__init__()self.layers = nn.ModuleList([TransformerEncoderLayer(d_model, num_heads) for _ in range(num_layers)])def forward(self, X):for layer in self.layers:X = layer(X)return X

5. BERT 模型

class BERT(nn.Module):def __init__(self, vocab_size, num_layers, d_model, num_heads):super(BERT, self).__init__()self.embedding = nn.Embedding(vocab_size, d_model)self.transformer = TransformerEncoder(num_layers, d_model, num_heads)self.mlm_head = nn.Linear(d_model, vocab_size)self.nsp_head = nn.Linear(d_model, 2)def forward(self, X, masked_indices=None):X = self.embedding(X)X = self.transformer(X)if masked_indices is not None:masked_X = torch.gather(X, 1, masked_indices.unsqueeze(-1).repeat(1, 1, X.size(-1)))mlm_logits = self.mlm_head(masked_X)else:mlm_logits = Nonensp_logits = self.nsp_head(X[:, 0])return mlm_logits, nsp_logits

以上代碼實現了 Transformer 編碼器和 BERT 模型的基本結構。在實際應用中,還需要進行數據預處理、模型訓練和評估等步驟。

四、總結

Transformer 和 BERT 作為 NLP 領域的重要技術,以其獨特的架構和強大的性能,推動了 NLP 技術的快速發展。通過本文對 Transformer 和 BERT 的背景介紹、公式推導和代碼實現,相信大家對它們有了更深入的理解。隨著研究的不斷深入,Transformer 和 BERT 的應用場景也在不斷拓展,未來它們將在更多領域發揮重要作用。希望本文能為大家的學習和研究提供幫助,歡迎大家在評論區交流討論。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/902073.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/902073.shtml
英文地址,請注明出處:http://en.pswp.cn/news/902073.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

計算機網絡八股——HTTP協議與HTTPS協議

目錄 HTTP1.1簡述與特性 1. 報文清晰易讀 2. 靈活和易于擴展 3. ?狀態 Cookie和Session 4. 明?傳輸、不安全 HTTP協議發展過程 HTTP/1.1的不足 HTTP/2.0 HTTP/3.0 HTTPS協議 HTTP協議和HTTPS協議的區別 HTTPS中的加密方式 HTTPS中建立連接的方式 前言&#xff…

QML中的3D功能--入門開發

Qt Quick 提供了強大的 3D 功能支持,主要通過 Qt 3D 模塊實現。以下是 QML 中開發 3D 應用的全面指南。 1. 基本配置 環境要求 Qt 5.10 或更高版本(推薦 Qt 6.x) 啟用 Qt 3D 模塊 支持 OpenGL 的硬件 項目配置 在 .pro 文件中添加: QT += 3dcore 3drender 3dinput 3dex…

Git合并分支的兩種常用方式`git merge`和`git cherry-pick`

Git合并分支的兩種常用方式git merge和git cherry-pick 寫在前面1. git merge用途工作方式使用git命令方式合并使用idea工具方式合并 2. git cherry-pick用途工作方式使用git命令方式合并使用idea工具方式合并 3. 區別總結 寫在前面 一般我們使用git合并分支常用的就是git mer…

Web三漏洞學習(其三:rce漏洞)

靶場:NSSCTF 三、RCE漏洞 1、概述 在Web應用開發中會讓應用調用代碼執行函數或系統命令執行函數處理,若應用對用戶的輸入過濾不嚴,容易產生遠程代碼執行漏洞或系統命令執行漏洞 所以常見的RCE漏洞函數又分為代碼執行函數和系統命令執行函數…

從零開始:Python運行環境之VSCode與Anaconda安裝配置全攻略 (1)

從零開始:Python 運行環境之 VSCode 與 Anaconda 安裝配置全攻略 在當今數字化時代,Python 作為一種功能強大且易于學習的編程語言,被廣泛應用于數據科學、人工智能、Web 開發等眾多領域。為了順利開啟 Python 編程之旅,搭建一個穩…

從FPGA實現角度介紹DP_Main_link主通道原理

DisplayPort(簡稱DP)是一個標準化的數字式視頻接口標準,具有三大基本架構包含影音傳輸的主要通道(Main Link)、輔助通道(AUX)、與熱插拔(HPD)。 Main Link:用…

嵌入式軟件--stm32 DAY 2

大家學習嵌入式的時候,多多學習用KEIL寫代碼,雖然作為編譯器,大家常用vscode等常用工具關聯編碼,但目前keil仍然是主流工具之一,學習掌握十分必要。 1.再次創建項目 1.1編譯器自動生成文件 1.2初始文件 這樣下次創建新…

游戲引擎學習第234天:實現基數排序

回顧并為今天的內容設定背景 我們今天繼續進行排序的相關,雖然基本已經完成了,但還是想收尾一下,讓整個流程更完整。其實這次排序只是個借口,主要是想順便聊一聊一些計算機科學的知識點,這些內容在我們項目中平時不會…

計算機網絡——常見的網絡攻擊手段

什么是XSS攻擊,如何避免? XSS 攻擊,全稱跨站腳本攻擊(Cross-Site Scripting),這會與層疊樣式表(Cascading Style Sheets, CSS)的縮寫混淆,因此有人將跨站腳本攻擊縮寫為XSS。它指的是惡意攻擊者往Web頁面…

Agent的九種設計模式 介紹

Agent的九種設計模式 介紹 一、ReAct模式 原理:將推理(Reasoning)和行動(Acting)相結合,使Agent能夠在推理的指導下采取行動,并根據行動的結果進一步推理,形成一個循環。Agent通過生成一系列的思維鏈(Thought Chains)來明確推理步驟,并根據推理結果執行相應的動作,…

LeetCode 熱題 100:回溯

46. 全排列 給定一個不含重復數字的數組 nums ,返回其 所有可能的全排列 。你可以 按任意順序 返回答案。 示例 1: 輸入:nums [1,2,3] 輸出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2: 輸入&#xff…

cJSON_Print 和 cJSON_PrintUnformatted的區別

cJSON_Print 和 cJSON_PrintUnformatted 是 cJSON 庫中用于將 cJSON 對象轉換為 JSON 字符串的兩個函數,它們的區別主要在于輸出的格式: 1. cJSON_Print 功能:將 cJSON 對象轉換為格式化的 JSON 字符串。 特點: 輸出的 JSON 字符…

A股周度復盤與下周策略 的deepseek提示詞模板

以下是反向整理的股票大盤分析提示詞模板,采用結構化框架數據占位符設計,可直接套用每周市場數據: 請根據一下markdown格式的模板,幫我檢索整理并輸出本周股市復盤和下周投資策略 【A股周度復盤與下周策略提示詞模板】 一、市場…

Linux下使用C++獲取硬件信息

目錄 方法獲取CPU信息:讀取"/proc/cpuinfo"文件獲取磁盤信息:讀取"/proc/diskstats"文件獲取BIOS信息有兩種方法:1、讀取文件;2、使用dmidecode命令獲取主板信息有兩種方法:1、讀取文件&#xff1…

BootStrap:進階使用(其二)

今天我要講述的是在BootStrap中第二篇關于進一步使用的方法與代碼舉例; 分頁: 對于一些大型網站而言,分頁是一個很有必要的存在,如果當數據內容過大時,則需要分頁來分擔一些,這可以使得大量內容能整合并全面地展示&a…

【技術派后端篇】技術派中的白名單機制:基于Redis的Set實現

在技術派社區中,為了保證文章的質量和社區的良性發展,所有發布的文章都需要經過審核。然而,并非所有作者的文章都需要審核,我們通過白名單機制來優化這一流程。本文將詳細介紹技術派中白名單的實現方式,以及如何利用Re…

TRAE.AI 國際版本

國際版下載地址: https://www.trae.ai/https://www.trae.ai/ 國際版本優勢:提供更多高校的AI助手模型 Claude-3.5-Sonnet Claude-3.7-Sonnet Gemini-2.5-Pro GPT-4.1 GPT-40 DeepSeek-V3-0324DeepSeek-V3DeepSeek-Reasoner(R1)

關于支付寶網頁提示非官方網頁

關于支付寶網站提示 非官方網站 需要找官方添加白名單 下面可以直接用自己的郵箱去發送申請 支付寶提示“非支付寶官方網頁,請確認是否繼續訪問”通常是因為支付寶的安全機制檢測到您訪問的頁面不是支付寶官方頁面,這可能是由于域名或頁面內容不符合支…

【今日三題】打怪(模擬) / 字符串分類(字符串哈希) / 城市群數量(dfs)

??個人主頁&#xff1a;小羊 ??所屬專欄&#xff1a;每日兩三題 很榮幸您能閱讀我的文章&#xff0c;誠請評論指點&#xff0c;歡迎歡迎 ~ 目錄 打怪(模擬)字符串分類(字符串哈希)城市群數量(dfs) 打怪(模擬) 打怪 #include <iostream> using namespace std;int …

npm install 版本過高引發錯誤,請添加 --legacy-peer-deps

起因&#xff1a;由于使用"react": "^19.0.0", 第三方包要低版本react&#xff0c;錯解決方法&#xff01; npm install --save emoji-mart emoji-mart/data emoji-mart/react npm install --save emoji-mart emoji-mart/data emoji-mart/react npm err…