基于Transformer的機器翻譯——模型篇

1.模型結構

本案例整體采用transformer論文中提出的結構,部分設置做了調整。transformer網絡結構介紹可參考博客——入門級別的Transformer模型介紹,這里著重介紹其代碼實現。
模型的整體結構,包括詞嵌入層,位置編碼,編碼器,解碼器、輸出層部分。

2.詞嵌入層

詞嵌入層用于將token轉化為詞向量,該層可直接調用nn模塊中的Embedding方法。該方法主要包括兩個參數,分別表示詞表的大小(vocab_size)和詞嵌入的維度(emb_size),同時為了訓練更穩定,加入了縮放因子dk\sqrt {d_k}dk??,代碼如下:

class TokenEmbedding(nn.Module):def __init__(self, vocab_size: int, emb_size):super(TokenEmbedding, self).__init__()# 詞嵌入層:將詞索引映射到emb_size維的向量self.embedding = nn.Embedding(vocab_size, emb_size)# 記錄嵌入維度(用于縮放)self.emb_size = emb_sizedef forward(self, tokens: Tensor):# 將詞索引轉換為詞向量,并乘以√emb_size(縮放,穩定梯度)return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

3.位置編碼

位置編碼層用于給序列添加位置信息,解決自注意力機制無法感知序列順序的問題。公式為:
PE(pos,2i)=sin(pos10002id)PE(pos,2i)=sin(\frac{pos}{1000\frac{2i}{d}})PE(pos,2i)=sin(1000d2i?pos?)
PE(pos,2i+1)=cos(pos10002id)PE(pos,2i+1)=cos(\frac{pos}{1000\frac{2i}{d}})PE(pos,2i+1)=cos(1000d2i?pos?)
代碼表示如下:

class PositionalEncoding(nn.Module):def __init__(self, emb_size: int, dropout, maxlen: int = 5000):super(PositionalEncoding, self).__init__()# 計算位置編碼的衰減因子(控制正弦/余弦函數的頻率)den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)# 位置索引(0到maxlen-1)pos = torch.arange(0, maxlen).reshape(maxlen, 1)# 初始化位置編碼矩陣(形狀:[maxlen, emb_size])pos_embedding = torch.zeros((maxlen, emb_size))# 偶數列用正弦函數填充(pos * den)pos_embedding[:, 0::2] = torch.sin(pos * den)# 奇數列用余弦函數填充(pos * den)pos_embedding[:, 1::2] = torch.cos(pos * den)# 調整維度(添加批次維度,便于與詞嵌入向量相加)pos_embedding = pos_embedding.unsqueeze(-2)# Dropout層(正則化,防止過擬合)self.dropout = nn.Dropout(dropout)# 注冊為緩沖區(模型保存/加載時自動處理)self.register_buffer('pos_embedding', pos_embedding)def forward(self, token_embedding: Tensor):# 將詞嵌入向量與位置編碼相加,并應用Dropoutreturn self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0),:])

4.編碼器

由于編碼器部分是通過堆疊多個子編碼器層所構成的,子編碼器包括:多頭自注意力層、殘差連接與歸一化、前饋網絡三部分,該部分代碼全部被封裝成TransformerEncoderLayer函數中,使用時只需要傳遞相應超參數即可,如詞嵌入維度、多頭注意力的頭數、前饋網絡的隱含層維度,代碼實現為:

# 定義編碼器層(單頭注意力→多頭注意力→前饋網絡)
encoder_layer = TransformerEncoderLayer(d_model=emb_size,       # 輸入特征維度(與詞嵌入維度一致)nhead=NHEAD,            # 多頭注意力的頭數dim_feedforward=dim_feedforward  # 前饋網絡隱藏層維度
)
# 堆疊多層編碼器層形成完整編碼器
self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

5.解碼器

解碼器同編碼器類似,代碼可以表述為:

# 定義解碼器層(掩碼多頭注意力→編碼器-解碼器多頭注意力→前饋網絡)
decoder_layer = TransformerDecoderLayer(d_model=emb_size,       # 輸入特征維度(與詞嵌入維度一致)nhead=NHEAD,            # 多頭注意力頭數(與編碼器一致)dim_feedforward=dim_feedforward  # 前饋網絡隱藏層維度
)
# 堆疊多層解碼器層形成完整解碼器
self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

6.輸出層

輸出通過線性層得到每個單詞的得分,可直接通過Linear層直接實現。

7.大體代碼

基于上述介紹,完整代碼如下:

from torch.nn import (TransformerEncoder, TransformerDecoder,TransformerEncoderLayer, TransformerDecoderLayer)class Seq2SeqTransformer(nn.Module):"""基于Transformer的序列到序列翻譯模型(日中機器翻譯核心模塊)包含編碼器(處理源語言序列)和解碼器(生成目標語言序列)"""def __init__(self, num_encoder_layers: int, num_decoder_layers: int,emb_size: int, src_vocab_size: int, tgt_vocab_size: int,dim_feedforward: int = 512, dropout: float = 0.1):"""初始化Transformer模型參數和組件:param num_encoder_layers: 編碼器層數(論文中通常為6,此處根據計算資源調整):param num_decoder_layers: 解碼器層數(與編碼器層數一致):param emb_size: 詞嵌入維度(對應Transformer的d_model,需與多頭注意力維度匹配):param src_vocab_size: 源語言(日語)詞表大小:param tgt_vocab_size: 目標語言(中文)詞表大小:param dim_feedforward: 前饋網絡隱藏層維度(通常為4*d_model):param dropout:  dropout概率(用于正則化,防止過擬合)"""super(Seq2SeqTransformer, self).__init__()# 定義編碼器層(單頭注意力→多頭注意力→前饋網絡)encoder_layer = TransformerEncoderLayer(d_model=emb_size,       # 輸入特征維度(與詞嵌入維度一致)nhead=NHEAD,            # 多頭注意力的頭數(需滿足 emb_size % nhead == 0)dim_feedforward=dim_feedforward  # 前饋網絡隱藏層維度)# 堆疊多層編碼器層形成完整編碼器self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)# 定義解碼器層(掩碼多頭注意力→編碼器-解碼器多頭注意力→前饋網絡)decoder_layer = TransformerDecoderLayer(d_model=emb_size,       # 輸入特征維度(與詞嵌入維度一致)nhead=NHEAD,            # 多頭注意力頭數(與編碼器一致)dim_feedforward=dim_feedforward  # 前饋網絡隱藏層維度)# 堆疊多層解碼器層形成完整解碼器self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)# 生成器:將解碼器輸出映射到目標詞表(預測每個位置的目標詞)self.generator = nn.Linear(emb_size, tgt_vocab_size)# 源語言詞嵌入層(將詞索引轉換為連續向量)self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)# 目標語言詞嵌入層(與源語言共享嵌入層可提升效果,此處未共享)self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)# 位置編碼層(注入序列位置信息,解決Transformer的位置無關性)self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,tgt_mask: Tensor, src_padding_mask: Tensor,tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):"""前向傳播(訓練時使用教師強制,輸入完整目標序列):param src: 源語言序列張量(形狀:[seq_len, batch_size]):param trg: 目標語言序列張量(形狀:[seq_len, batch_size]):param src_mask: 源序列注意力掩碼(形狀:[seq_len, seq_len],全0表示無掩碼):param tgt_mask: 目標序列掩碼(下三角掩碼,防止關注未來詞):param src_padding_mask: 源序列填充掩碼(標記<pad>位置,形狀:[batch_size, seq_len]):param tgt_padding_mask: 目標序列填充掩碼(標記<pad>位置,形狀:[batch_size, seq_len]):param memory_key_padding_mask: 編碼器輸出的填充掩碼(與src_padding_mask一致):return: 目標序列的詞表概率分布(形狀:[seq_len, batch_size, tgt_vocab_size])"""# 源序列處理:詞嵌入 + 位置編碼src_emb = self.positional_encoding(self.src_tok_emb(src))# 目標序列處理:詞嵌入 + 位置編碼(訓練時使用教師強制,輸入完整目標序列)tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))# 編碼器處理源序列,生成記憶向量(memory)memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)# 解碼器利用記憶向量生成目標序列outs = self.transformer_decoder(tgt_emb,                # 目標序列嵌入(含位置信息)memory,                 # 編碼器輸出的記憶向量tgt_mask,               # 目標序列掩碼(防止未來詞)None,                   # 編碼器-解碼器注意力掩碼(此處未使用)tgt_padding_mask,       # 目標序列填充掩碼(忽略<pad>)memory_key_padding_mask # 記憶向量填充掩碼(與源序列填充掩碼一致))# 通過生成器輸出目標詞表的概率分布return self.generator(outs)def encode(self, src: Tensor, src_mask: Tensor):"""編碼源序列(推理時單獨調用,生成編碼器記憶向量):param src: 源語言序列張量(形狀:[seq_len, batch_size]):param src_mask: 源序列注意力掩碼(形狀:[seq_len, seq_len]):return: 編碼器輸出的記憶向量(形狀:[seq_len, batch_size, emb_size])"""return self.transformer_encoder(self.positional_encoding(self.src_tok_emb(src)),  # 源序列嵌入+位置編碼src_mask  # 源序列注意力掩碼)def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):"""解碼目標序列(推理時逐步生成目標詞):param tgt: 當前已生成的目標序列前綴(形狀:[current_seq_len, batch_size]):param memory: 編碼器輸出的記憶向量(形狀:[seq_len, batch_size, emb_size]):param tgt_mask: 目標序列掩碼(下三角掩碼,防止關注未來詞):return: 解碼器輸出(形狀:[current_seq_len, batch_size, emb_size])"""return self.transformer_decoder(self.positional_encoding(self.tgt_tok_emb(tgt)),  # 目標前綴嵌入+位置編碼memory,  # 編碼器記憶向量tgt_mask  # 目標前綴掩碼(僅允許關注已生成部分))
class PositionalEncoding(nn.Module):def __init__(self, emb_size: int, dropout, maxlen: int = 5000):super(PositionalEncoding, self).__init__()# 計算位置編碼的衰減因子(控制正弦/余弦函數的頻率)den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)# 位置索引(0到maxlen-1)pos = torch.arange(0, maxlen).reshape(maxlen, 1)# 初始化位置編碼矩陣(形狀:[maxlen, emb_size])pos_embedding = torch.zeros((maxlen, emb_size))# 偶數列用正弦函數填充(pos * den)pos_embedding[:, 0::2] = torch.sin(pos * den)# 奇數列用余弦函數填充(pos * den)pos_embedding[:, 1::2] = torch.cos(pos * den)# 調整維度(添加批次維度,便于與詞嵌入向量相加)pos_embedding = pos_embedding.unsqueeze(-2)# Dropout層(正則化,防止過擬合)self.dropout = nn.Dropout(dropout)# 注冊為緩沖區(模型保存/加載時自動處理)self.register_buffer('pos_embedding', pos_embedding)def forward(self, token_embedding: Tensor):# 將詞嵌入向量與位置編碼相加,并應用Dropoutreturn self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0),:])class TokenEmbedding(nn.Module):def __init__(self, vocab_size: int, emb_size):super(TokenEmbedding, self).__init__()# 詞嵌入層:將詞索引映射到emb_size維的向量self.embedding = nn.Embedding(vocab_size, emb_size)# 記錄嵌入維度(用于縮放)self.emb_size = emb_sizedef forward(self, tokens: Tensor):# 將詞索引轉換為詞向量,并乘以√emb_size(縮放,穩定梯度)return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

結語

至此,模型已完成搭建,后續博客將繼續介紹模型訓練部分的內容,希望本篇博客能夠對你理解transformer有所幫助!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/93578.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/93578.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/93578.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

上位機TCP/IP通信協議層常見問題匯總

以太網 TCP 通信是上位機開發中常用的通信方式&#xff0c;西門子 S7 通信、三菱 MC 通信以及 MQTT、OPC UA、Modbus TCP 等都是其典型應用。為幫助大家更好地理解 TCP 通信&#xff0c;我整理了一套常見問題匯總。一、OSI參考模型與TCP/IP參考模型基于TCP/IP的參考模型將協議分…

搭建ktg-mes

項目地址 該安裝事項&#xff0c;基于當前最新版 2025年8月16日 之前的版本 下載地址&#xff1a; 后端JAVA 前端VUE 后端安裝&#xff1a; 還原數據表 路徑&#xff1a;根目錄/sql/ry_20210908.sql、根目錄/sql/quartz.sql、根目錄/doc/實施文檔/ktgmes-202505180846.sql.g…

uniapp純前端繪制商品分享圖

效果如圖// useMpCustomShareImage.ts interface MpCustomShareImageData {canvasId: stringprice: stringlinePrice: stringgoodsSpecFirmName: stringimage: string }const CANVAS_WIDTH 500 const CANVAS_HEIGHT 400 const BG_IMAGE https://public-scjuchuang.oss-cn-ch…

醋酸鑭:看不見的科技助力

雖然我們每天都在使用各種科技產品&#xff0c;但有些關鍵的化學物質卻鮮為人知。醋酸鑭&#xff0c;就是這樣一種默默為科技進步貢獻力量的“幕后英雄”。它不僅是稀土元素鑭的一種化合物&#xff0c;還在許多高科技領域中發揮著重要作用。今天&#xff0c;讓我們一起來了解這…

蒼穹外賣日記

day 1 windows系統啟動nginx報錯: The system cannot find the path specified 在啟動nginx的時候報錯&#xff1a; /temp/client_body_temp" failed (3: The system cannot find the path specified) 解決辦法&#xff1a; 1.檢查nginx的目錄是否存在中文 &#xff0c;路…

樓宇自控系統賦能建筑全維度管理,實現環境、安全與能耗全面監管

隨著城市化進程加速和綠色建筑理念普及&#xff0c;現代樓宇管理正經歷從粗放式運營向精細化管控的轉型。樓宇自控系統&#xff08;BAS&#xff09;作為建筑智能化的核心載體&#xff0c;通過物聯網、大數據和人工智能技術的深度融合&#xff0c;正在重構建筑管理的全維度框架&…

【HarmonyOS】Window11家庭中文版開啟鴻蒙模擬器失敗提示未開啟Hyoer-V

【HarmonyOS】Window11家庭中文版開啟鴻蒙模擬器失敗提示未開啟Hyoer-V一、問題背景 當鴻蒙模擬器啟動時&#xff0c;提示如下圖所示&#xff1a;因為Hyper-V 僅在 Windows 11 專業版、企業版和教育版中作為預裝功能提供&#xff0c;而家庭版&#xff08;包括中文版&#xff09…

vscode遠程服務器出現一直卡在正在打開遠程和連接超時解決辦法

項目場景&#xff1a; 使用ssh命令或者各種軟件進行遠程服務器之后&#xff0c;結果等到幾分鐘之后自動斷開連接問題解決。vscode遠程服務器一直卡在正在打開遠程狀態問題解決。問題描述 1.連接超時 2.vscode遠程一直卡在正在打開遠程...原因分析&#xff1a;需要修改設置超時斷…

Maven下載和配置-IDEA使用

目錄 一 MAVEN 二 三個倉庫 1. 本地倉庫&#xff08;Local Repository&#xff09; 2. 私有倉庫&#xff08;Private Repository&#xff0c;公司內部倉庫&#xff09; 3. 遠程倉庫&#xff08;Remote Repository&#xff09; 依賴查找流程&#xff08;優先級&#xff09…

Dify實戰應用指南(上傳需求稿生成測試用例)

一、Dify平臺簡介 Dify是一款開源的大語言模型&#xff08;LLM&#xff09;應用開發平臺&#xff0c;融合了“Define&#xff08;定義&#xff09; Modify&#xff08;修改&#xff09;”的設計理念&#xff0c;通過低代碼/無代碼的可視化界面降低技術門檻。其核心價值在于幫助…

學習日志35 python

1 Python 列表切片一、切片完整語法列表切片的基本格式&#xff1a; 列表[start:end:step]start&#xff1a;起始索引&#xff08;包含該位置元素&#xff0c;可省略&#xff09;end&#xff1a;結束索引&#xff08;不包含該位置元素&#xff0c;可省略&#xff09;step&#…

Linux -- 文件【下】

目錄 一、EXT2文件系統 1、宏觀認識 2、塊組內部構成 2.1 Data Block 2.2 i節點表(Inode Table) 2.3 塊位圖&#xff08;Block Bitmap&#xff09; 2.4 inode位圖&#xff08;Inode Bitmap&#xff09; 2.5 GDT&#xff08;Group Descriptor Table&#xff09; 2.6 超…

谷歌手機刷機和面具ROOT保姆級別教程

#比較常用的谷歌輸入root面具教程,逆向工程師必修課程# 所需工具與材料清單 真機設備 推薦使用 Google Pixel 4 或其他兼容設備&#xff0c;確保硬件支持刷機操作。 ADB 環境配置 通過安裝 Android Studio 自動配置 ADB 和 Fastboot 工具。安裝完成后&#xff0c;需在系統環境…

平衡二叉搜索樹 - 紅黑樹詳解

文章目錄一、紅黑樹概念引申問題二、紅黑樹操作一、紅黑樹概念 紅黑樹是一棵二叉搜索樹&#xff0c;它在每個節點上增加了一個存儲位用來表示節點顏色(紅色或者黑色)&#xff0c;紅黑樹通過約束顏色&#xff0c;可以保證最長路徑不超過最短路徑的兩倍&#xff0c;因而近似平衡…

從0開始跟小甲魚C語言視頻使用linux一步步學習C語言(持續更新)8.14

第十六天 第五十二&#xff0c;五十三&#xff0c;五十四&#xff0c;五十五和五十六集 第五十二集 文件包含 一個include命令只能指定一個被包含文件 文件允許嵌套&#xff0c;就是一個被包含的文件可以包含另一個文件。 文件名可以用尖括號或者雙引號括起來 但是兩種的查找方…

B+樹索引分析:單表最大存儲記錄數

在現代數據庫設計中&#xff0c;隨著數據量的增加&#xff0c;如何有效地管理和優化數據庫成為了一個關鍵問題。根據阿里巴巴開發手冊的標準&#xff0c;當一張表預計在三年內的數據量超過500萬條或者2GB時&#xff0c;就應該考慮實施分庫分表策略 Mysql B樹索引介紹 及 頁內儲…

三、memblock 內存分配器

兩個問題&#xff1a; 1、系統是怎么知道物理內存的&#xff1f;linux內存管理學習&#xff08;1&#xff09;&#xff1a;物理內存探測 2、在內存管理真正初始化之前&#xff0c;內核的代碼執行需要分配內存該怎么處理&#xff1f; 在Linux內核啟動初期&#xff0c;完整的內存…

Python 桌面應用形態后臺管理系統的技術選型與方案報告

下面是一份面向“Python 桌面應用形態的后臺管理系統”的技術選型與方案報告。我把假設前提→總體架構→客戶端技術選型→服務端與數據層→基礎設施與安全→交付與運維→質量保障→里程碑計劃→風險與對策→最小可行棧逐層給出。 一、前置假設 & 非功能目標 業務假設 典型…

Winsows系統去除右鍵文件顯示的快捷列表

前言&#xff1a;今天重做了電腦系統&#xff0c;安裝的是純凈版的系統。然后手動指定D盤安裝了下列軟件。&#xff08;QQ&#xff0c;迅雷&#xff0c;百度網盤&#xff0c;搜狗輸入法&#xff0c;驅動精靈&#xff09;然后我右鍵點擊桌面的軟件快捷方式&#xff0c;出現了一排…

【Go】Gin 超時中間件的坑:fatal error: concurrent map writes

Gin 社區超時中間件的坑&#xff1a;導致線上 Pod 異常重啟 在最近的項目中&#xff0c;我們遇到了因為 Gin 超時中間件&#xff08;timeout&#xff09; 引發的生產事故&#xff1a;Pod 異常退出并重啟。 問題現場 pod無故重啟&#xff0c;抓取標準輸出日志&#xff0c;問題…