Seq2Seq - 編碼器(Encoder)和解碼器(Decoder)

本節實現一個簡單的 Seq2Seq(Sequence to Sequence)模型 的編碼器(Encoder)和解碼器(Decoder)部分。

?重點把握Seq2Seq 模型的整體工作流程

理解編碼器(Encoder)和解碼器(Decoder)代碼

本小節引入了nn.GRU API的調用,nn.GRU具體參數將在下一小節進行補充講解

1.?編碼器(Encoder

類定義
class Encoder(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_size):super().__init__()self.emb = nn.Embedding(vocab_size, embedding_dim)self.rnn = nn.GRU(embedding_dim, hidden_size, batch_first=True)
  • vocab_size:輸入詞匯表的大小,即輸入序列中可能出現的不同單詞或標記的數量。

  • embedding_dim:嵌入層的維度,即每個單詞或標記被映射到的向量空間的維度。

  • hidden_size:GRU(門控循環單元)的隱藏狀態維度,決定了模型的內部狀態大小。

主要組件
  1. 嵌入層(nn.Embedding

    • 嵌入層會將輸入序列形狀轉換為?[batch_size, seq_len, embedding_dim] 的張量。

    • 這種映射是通過學習嵌入矩陣實現的,每個單詞索引對應嵌入矩陣中的一行。

  2. GRU(nn.GRU

    • embedding_dim 是 GRU 的輸入維度,hidden_size 是隱藏狀態的維度。

    • batch_first=True 表示輸入和輸出的張量的第一個維度是批量大小(batch_size),而不是序列長度(seq_len)。

前向傳播(forward
def forward(self, x):embs = self.emb(x) #batch * token * embedding_dimgru_out, hidden = self.rnn(embs) #batch * token * hidden_sizereturn gru_out, hidden
  • 輸入 x 是一個形狀為 [batch_size, seq_len] 的張量,表示一個批次的輸入序列。

  • embs 是嵌入層的輸出,形狀為 [batch_size, seq_len, embedding_dim]

  • gru_out 是 GRU 的輸出,形狀為 [batch_size, seq_len, hidden_size],表示每個時間步的隱藏狀態。

  • hidden 是 GRU 的最終隱藏狀態,形狀為 [1, batch_size, hidden_size]用于傳遞給解碼器。

?

2.?解碼器(Decoder)

類定義
class Decoder(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_size):super().__init__()self.emb = nn.Embedding(vocab_size, embedding_dim)self.rnn = nn.GRU(embedding_dim, hidden_size, batch_first=True)
  • 解碼器的結構與編碼器類似,但它的作用是將編碼器生成的上下文向量(hidden)解碼為目標序列。

主要組件
  1. 嵌入層(nn.Embedding

    • 與編碼器類似,將目標序列的單詞索引映射到嵌入向量。

  2. GRU(nn.GRU

    • 與編碼器中的 GRU 類似,但其輸入是目標序列的嵌入向量,初始隱藏狀態是編碼器的最終隱藏狀態。

前向傳播(forward
def forward(self, x, hx):embs = self.emb(x)gru_out, hidden = self.rnn(embs, hx=hx) #batch * token * hidden_size# batch * token * hidden_size# 1 * token * hidden_sizereturn gru_out, hidden
  • 輸入 x 是目標序列的單詞索引,形狀為 [batch_size, seq_len]

  • hx 是編碼器的最終隱藏狀態,形狀為 [1, batch_size, hidden_size]作為解碼器的初始隱藏狀態。

  • embs 是目標序列的嵌入向量,形狀為 [batch_size, seq_len, embedding_dim]

  • gru_out 是解碼器 GRU 的輸出,形狀為 [batch_size, seq_len, hidden_size]

  • hidden 是解碼器 GRU 的最終隱藏狀態,形狀為 [1, batch_size, hidden_size]

3.?Seq2Seq 模型的整體工作流程?

  1. 編碼階段

    • 輸入序列通過編碼器的嵌入層,將單詞索引映射為嵌入向量。

    • 嵌入向量通過 GRU,生成每個時間步的隱藏狀態和最終的隱藏狀態(上下文向量)。

    • 最終隱藏狀態(hidden)作為編碼器的輸出,傳遞給解碼器。

  2. 解碼階段

    • 解碼器的初始隱藏狀態是編碼器的最終隱藏狀態。

    • 解碼器逐個生成目標序列的單詞,每次生成一個單詞后,將該單詞的嵌入向量作為下一次輸入,同時更新隱藏狀態。

    • 通過這種方式,解碼器逐步生成目標序列。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/77134.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/77134.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/77134.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Spring Boot集成MinIO的詳細步驟

1. 安裝MinIO 使用Docker部署MinIO 拉取MinIO鏡像: docker pull minio/minio 這將從Docker Hub中獲取最新的MinIO鏡像。 創建目錄: mkdir -p /home/minio/config mkdir -p /home/minio/data 這些目錄將用于持久化MinIO的數據和配置文件 創建MinIO…

基于PLC的停車場車位控制系統的設計

2.1 設計目標 本課題為基于PLC的停車場車位控制系統來設計,在此將功能確定如下: 針對8個車位的停車場進行設計將停車場分為入口處,車位處、以及出口處三個部分;每個車位都有指示燈指示當前位置是否空閑,方便司機查找空…

微服務即時通信系統---(四)框架學習

目錄 ElasticSearch 介紹 安裝 安裝kibana ES客戶端安裝 頭文件包含和編譯時鏈接庫 ES核心概念 索引(Index) 類型(Type) 字段(Field) 映射(mapping) 文檔(document) ES對比MySQL Kibana訪問ES測試 創建索引庫 新增數據 查看并搜索數據 刪除索引 ES…

除了 `task_type=“SEQ_CLS“`(序列分類),還有CAUSAL_LM,QUESTION_ANS

task_type="SEQ_CLS"是什么意思:QUESTION_ANS 我是qwen,不同模型是不一樣的 SEQ_CLS, SEQ_2_SEQ_LM, CAUSAL_LM, TOKEN_CLS, QUESTION_ANS, FEATURE_EXTRACTION. task_type="SEQ_CLS" 通常用于自然語言處理(NLP)任務中,SEQ_CLS 是 Sequence Classif…

Android ViewPager使用預加載機制導致出現頁面穿透問題

? 緣由 在應用中使用ViewPager,并且設置預加載頁面。結果出現了一些異常的現象。 我們有4個頁面,分別是4個Fragment,暫且稱為FragmentA、FragmentB、FragmentC、FragmentD,ViewPager在MainActivity中,切換時&#x…

apt3.0和apt2.0的區別

一,簡單區別 更新方式 apt2.0:一次性更新所有內容,沒有分階段更新功能。apt3.0:引入分階段更新功能,可分批推送更新包。 界面顯示 apt2.0:界面簡單,輸出信息較為雜亂,沒有彩色高亮和…

過電壓保護器與傳統的保護方式對比

過電壓保護器主要用于保護電氣設備免受大氣過電壓(如雷擊)和操作過電壓(開關動作等引發)的侵害。它通常由非線性電阻片等元件組成,利用其獨特的伏安特性工作。正常電壓下,保護器呈現高阻態,幾乎…

機器學習(3)——決策樹

文章目錄 1. 決策樹基本原理1.1. 什么是決策樹?1.2. 決策樹的基本構成:1.3. 核心思想 2. 決策樹的構建過程2.1. 特征選擇2.1.1. 信息增益(ID3)2.1.2. 基尼不純度(CART)2.1.3. 均方誤差(MSE&…

充電樁領域垂直行業大模型分布式推理與訓練平臺建設方案 - 慧知開源充電樁平臺

沒有任何廣告! 充電樁領域垂直行業大模型分布式推理與訓練平臺建設方案 一、平臺定位與核心價值 行業首個垂直化AI平臺 專為充電樁運營場景設計的分布式大模型訓練與推理基礎設施,實現"算力-算法-場景"三位一體閉環管理。 核心價值主張&am…

NLP高頻面試題(四十五)——PPO 算法在 RLHF 中的原理與實現詳解

近端策略優化(Proximal Policy Optimization, PPO)算法是強化學習領域的一種新穎且高效的策略優化方法,在近年大規模語言模型的人類反饋強化學習(Reinforcement Learning with Human Feedback, RLHF)中發揮了關鍵作用。本文將以學術嚴謹的風格,詳細闡述 PPO 算法的原理及…

C++指針和引用之區別(The Difference between C++Pointers and References)

面試題:C指針和引用有什么區 C指針和引用有什么區別? 在 C 中,指針和引用都是用來訪問其他變量的值的方式,但它們之間存在一些重要的區別。了解這些區別有助于更好地理解和使用這兩種工具。 01 指針 指針(Pointer…

LWIP學習筆記

TCP/ip協議結構分層 傳輸層簡記 TCP:可靠性強,有重傳機制 UDP:單傳機制,不可靠 UDP在ip層分片 TCP在傳輸層分包 應用層傳輸層網絡層,構成LWIP內核程序: 鏈路層;由mac內核STM芯片的片上外設…

【經驗記錄貼】活用shell,提高工作效率

背景 最近在做測試的時候,需要手動kill服務的進程,然后通過命令重啟服務,再進行測試。每次重啟都會涉及到下面三個命令的執行: 1)檢索進程ID $ ps -eLf | grep programname root 1123 112 1234 0 0 0 0:00…

MacOS 系統下 Git 的詳細安裝步驟與基礎設置指南

MacOS 系統下 Git 的詳細安裝步驟與基礎設置指南—目錄 一、安裝 Git方法 1:通過 Homebrew 安裝(推薦)方法 2:通過 Xcode Command Line Tools 安裝方法 3:手動下載安裝包 二、基礎配置1. 設置全局用戶名和郵箱2. 配置 …

一文讀懂 AI

2022年11月30日,OpenAI發布了ChatGPT,2023年3月15日,GPT-4引發全球轟動,讓世界上很多人認識了ai這個詞。如今已過去快兩年半,AI產品層出不窮,如GPT-4、DeepSeek、Cursor、自動駕駛等,但很多人仍…

【教程】檢查RDMA網卡狀態和測試帶寬 | 附測試腳本

轉載請注明出處:小鋒學長生活大爆炸[xfxuezhagn.cn] 如果本文幫助到了你,歡迎[點贊、收藏、關注]哦~ 目錄 檢查硬件和驅動狀態 測試RDMA通信 報錯修復 對于交換機的配置,可以看這篇: 【教程】詳解配置多臺主機通過交換機實現互…

計算機網絡 - TCP協議

通過一些問題來討論 TCP 協議 什么是 TCP ?舉幾個應用了 TCP 協議的例子TCP協議如何保證可靠性?tcp如何保證不會接受重復的報文?Tcp粘包拆包問題了解嗎?介紹一下,如何解決?TCP擁塞控制與流量控制區別&…

Fiddler 進行斷點測試:調試網絡請求

目錄 一、什么是斷點測試? 二、Fiddler 的斷點功能 三、如何在 Fiddler 中設置斷點? 步驟 1:啟動 Fiddler 步驟 2:啟用斷點 步驟 3:捕獲請求 步驟 4:修改請求或響應 四、案例:模擬登錄失…

OpenCv高階(三)——圖像的直方圖、圖像直方圖的均衡化

目錄 一、直方圖 1、計算并顯示直方圖 2、使用matplotlib方法繪制直方圖(不劃分小的子區間) 3、使用opencv的方法繪制直方圖 (劃分16個小的子亮度區間) 4、繪制彩色圖像的直方圖,將各個通道的直方圖值都畫出來 二、…

Flutter 與原生通信

Flutter 與原生之間的通信主要基于通道機制,包括 MethodChannel、EventChannel 和 BasicMessageChannel。 MethodChannel:用于 Flutter 與原生之間的方法調用,實現雙向通信,適合一次性的方法調用并獲取返回值,如 Flut…