從代碼學習深度學習 - Bahdanau注意力 PyTorch版

文章目錄

    • 1. 前言
      • 為什么選擇Bahdanau注意力
      • 本文目標與預備知識
    • 2. Bahdanau注意力機制概述
      • 注意力機制簡述
      • 加性注意力與乘性注意力對比
      • Bahdanau注意力的數學原理與流程圖
        • 數學原理
        • 流程圖
        • 可視化與直觀理解
    • 3. 數據準備與預處理
      • 數據集簡介
      • 數據加載與預處理
        • 1. 讀取數據集
        • 2. 預處理文本
        • 3. 詞元化
      • 詞表構建
      • 序列截斷與填充
      • 構建張量與有效長度
      • 創建數據迭代器
      • 數據準備的關鍵點
      • 與Bahdanau注意力的關聯
      • 總結
    • 4. 模型組件搭建
      • 4.1 總體架構概述
      • 4.2 編碼器(Encoder)
      • 4.3 解碼器(Decoder)
      • 4.4 Bahdanau注意力機制(AdditiveAttention)
      • 4.5 屏蔽機制(sequence_mask 和 masked_softmax)
        • sequence_mask
        • masked_softmax
      • 4.6 數據加載與模型整合
      • 4.7 關鍵點與優勢
      • 4.8 可視化與驗證
      • 4.9 總結
  • 5. 訓練流程實現
    • 5.1 數據加載
    • 5.2 模型定義
    • 5.3 訓練過程
      • 5.3.1 權重初始化
      • 5.3.2 優化器和損失函數
      • 5.3.3 訓練循環
      • 5.3.4 訓練結果輸出
    • 5.4 預測與評估
      • 5.4.1 預測實現
      • 5.4.2 BLEU 分數評估
      • 5.4.3 注意力權重可視化
    • 5.5 實現亮點
    • 5.6 總結
  • 6. 模型推理與預測
    • 6.1 序列翻譯預測函數詳解
      • 6.1.1 函數定義與參數
      • 6.1.2 預處理階段
      • 6.1.3 編碼器前向傳播
      • 6.1.4 解碼器逐時間步預測
      • 6.1.5 輸出處理
      • 6.1.6 實現亮點
      • 6.1.7 潛在改進方向
    • 6.2 BLEU 評估指標解釋與實現
      • 6.2.1 BLEU 指標概述
      • 6.2.2 函數定義與參數
      • 6.2.3 計算邏輯與實現
        • 6.2.3.1 預處理
        • 6.2.3.2 長度懲罰
        • 6.2.3.3 n-gram 精確度
        • 6.2.3.4 返回結果
      • 6.2.4 BLEU 的意義與局限性
      • 6.2.5 實現亮點
      • 6.2.6 潛在改進方向
    • 6.3 總結
  • 7. 可視化注意力權重
    • 7.1 注意力熱圖繪制與分析
      • 7.1.1 代碼實現
      • 7.1.2 熱圖分析
      • 7.1.3 可視化效果
    • 7.2 模型關注詞元的可解釋性展示
      • 7.2.1 可解釋性意義
      • 7.2.2 可視化案例
      • 7.2.3 提升可解釋性的方法
    • 7.3 實現亮點
  • 8. 總結
    • 8.1 Bahdanau 注意力的實現經驗分享
    • 8.2 PyTorch 中模塊化建模的優勢
    • 8.3 下一步可以探索的方向
    • 8.4 總結


完整代碼:下載連接

1. 前言

為什么選擇Bahdanau注意力

在深度學習領域,尤其是自然語言處理(NLP)任務中,序列到序列(Seq2Seq)模型是許多應用的核心,如機器翻譯、文本摘要和對話系統等。傳統的Seq2Seq模型依賴于編碼器-解碼器架構,通過編碼器將輸入序列壓縮為固定長度的上下文向量,再由解碼器生成輸出序列。然而,這種方法在處理長序列時往往面臨信息丟失的問題,上下文向量難以捕捉輸入序列的全部細節。

Bahdanau注意力機制(Bahdanau et al., 2014)通過引入動態的上下文選擇機制,顯著提升了模型對輸入序列的利用效率。它允許解碼器在生成每個輸出時,動態地關注輸入序列的不同部分,而非依賴單一的上下文向量。這種機制不僅提高了翻譯質量,還為后續的注意力機制(如Transformer)奠定了基礎。選擇Bahdanau注意力作為學習對象,是因為它直觀地展示了注意力機制的核心思想,同時在實現上具有足夠的復雜度,能夠幫助我們深入理解深度學習的建模過程。

此外,PyTorch作為一個靈活且直觀的深度學習框架,非常適合實現和調試復雜的模型結構。通過本文的代碼分析,我們將以Bahdanau注意力為核心,結合PyTorch的模塊化編程,探索Seq2Seq模型的完整實現流程,為進一步學習Transformer等高級模型打下堅實基礎。

本文目標與預備知識

本文的目標是通過剖析一個基于PyTorch實現的Bahdanau注意力Seq2Seq模型,幫助讀者從代碼層面理解深度學習模型的設計與實現。我們將從數據預處理、模型組件搭建、訓練流程到推理與可視化,逐步拆解每個環節的核心代碼,揭示Bahdanau注意力機制的運作原理,并提供直觀的解釋和可視化結果。同時,通過模塊化代碼的分析,我們將展示如何在PyTorch中高效地組織復雜項目。

為了更好地理解本文內容,建議讀者具備以下預備知識:

  • Python編程基礎:熟悉Python語法、面向對象編程以及PyTorch的基本操作(如張量操作、模塊定義和自動求導)。
  • 深度學習基礎:了解神經網絡的基本概念(如前向傳播、反向傳播、損失函數和優化器),以及循環神經網絡(RNN)或門控循環單元(GRU)的工作原理。
  • NLP基礎:對詞嵌入(Word Embedding)、序列建模和機器翻譯任務有初步了解。
  • 數學基礎:熟悉線性代數(如矩陣運算)、概率論(softmax函數)以及基本的優化理論。

如果你對上述內容有所欠缺,不必擔心!本文將盡量通過代碼注釋和直觀的解釋,降低學習門檻,讓你能夠通過實踐逐步掌握Bahdanau注意力的精髓。

接下來,我們將進入Bahdanau注意力機制的詳細分析,從理論到代碼實現,帶你一步步走進深度學習的精彩世界!

2. Bahdanau注意力機制概述

注意力機制簡述

在深度學習領域,特別是在序列到序列(Seq2Seq)任務如機器翻譯中,注意力機制(Attention Mechanism)是一種革命性的技術,用于解決傳統Seq2Seq模型在處理長序列時的瓶頸問題。傳統Seq2Seq模型通過編碼器將輸入序列壓縮為一個固定長度的上下文向量,再由解碼器基于此向量生成輸出序列。然而,當輸入序列較長時,固定上下文向量難以充分捕捉所有輸入信息,導致信息丟失和翻譯質量下降。

注意力機制的提出,允許模型在生成輸出時動態地關注輸入序列的不同部分,而不是依賴單一的上下文向量。具體來說,注意力機制通過計算輸入序列每個位置與當前解碼步驟的相關性(注意力權重),為解碼器提供一個加權的上下文向量。這種動態聚焦的方式極大地提高了模型對長序列的建模能力,并增強了生成結果的可解釋性。

Bahdanau注意力(也稱為加性注意力,Additive Attention)是注意力機制的早期代表之一,首次提出于2014年的論文《Neural Machine Translation by Jointly Learning to Align and Translate》。它通過引入一個可學習的對齊模型,動態計算輸入序列與輸出序列之間的關聯,被廣泛應用于機器翻譯等任務。

加性注意力與乘性注意力對比

注意力機制根據計算注意力得分(Attention Score)的方式不同,可以分為加性注意力和乘性注意力(Dot-Product Attention)兩大類:

  • 加性注意力(Additive Attention)

    • 計算方式:Bahdanau注意力屬于加性注意力,其核心是通過將查詢(Query)和鍵(Key)映射到相同的隱藏維度后,相加并通過非線性激活函數(如tanh)處理,最后通過線性變換得到注意力得分。

    • 數學表達式
      score ( q , k i ) = w v ? ? tanh ? ( W q q + W k k i ) \text{score}(q, k_i) = w_v^\top \cdot \tanh(W_q q + W_k k_i) score(q,ki?)=wv???tanh(Wq?q+Wk?ki?)
      其中,(q)是查詢向量,(k_i)是鍵向量,(W_q)和(W_k)是可學習的權重矩陣,(w_v)是用于計算最終得分的權重向量。

    • 特點

      • 計算復雜度較高,因為需要對查詢和鍵進行線性變換并相加。
      • 適合查詢和鍵維度不同的場景,因為它通過映射統一了維度。
      • 在Bahdanau注意力中,注意力得分經過softmax歸一化,生成權重,用于加權求和值(Value)向量,形成上下文向量。
    • 代碼體現
      在提供的代碼中,AdditiveAttention類實現了這一過程:

      queries, keys = self.W_q(queries), self.W_k(keys)
      features = queries.unsqueeze(2) + keys.unsqueeze(1)
      features = torch.tanh(features)
      scores = self.w_v(features).squeeze(-1)
      self.attention_weights = masked_softmax(scores, valid_lens)
      
  • 乘性注意力(Dot-Product Attention)

    • 計算方式:乘性注意力通過查詢和鍵的點積直接計算得分,通常在查詢和鍵維度相同時使用。
    • 數學表達式
      score ( q , k i ) = q ? k i \text{score}(q, k_i) = q^\top k_i score(q,ki?)=q?ki?
      或其縮放版本(Scaled Dot-Product Attention):
      score ( q , k i ) = q ? k i d k \text{score}(q, k_i) = \frac{q^\top k_i}{\sqrt{d_k}} score(q,ki?)=dk? ?q?ki??
      其中, d k d_k dk?是鍵的維度,用于防止點積過大。
    • 特點
      • 計算效率較高,適合大規模并行計算,廣泛用于Transformer模型。
      • 假設查詢和鍵具有相同的維度,否則需要額外的映射。
      • 對于高維輸入,可能需要縮放以穩定訓練。
    • 適用場景
      乘性注意力在Transformer等現代模型中更為常見,但在Bahdanau注意力提出時,RNN-based的Seq2Seq模型更傾向于使用加性注意力,因為它能更好地處理變長序列和不同維度的輸入。

對比總結

  • 加性注意力(Bahdanau)通過顯式的非線性變換,靈活性更高,適合早期RNN模型,但計算開銷較大。
  • 乘性注意力(Luong或Transformer)計算簡單,效率高,適合現代GPU加速的場景,但在維度不匹配時需要額外處理。
  • Bahdanau注意力作為加性注意力的代表,為后續的乘性注意力機制奠定了理論基礎。

Bahdanau注意力的數學原理與流程圖

數學原理

Bahdanau注意力的核心目標是為解碼器的每個時間步生成一個上下文向量,該向量是輸入序列隱藏狀態的加權和,權重由注意力得分決定。其工作流程可以分解為以下步驟:

  1. 輸入

    • 編碼器輸出:編碼器(通常為GRU或LSTM)處理輸入序列,生成隱藏狀態序列 ( h 1 , h 2 , … , h T h_1, h_2, \dots, h_T h1?,h2?,,hT? ),其中 $T $ 是輸入序列長度,每個 h i h_i hi?是鍵(Key)和值(Value)。
    • 解碼器狀態:解碼器在時間步 t t t的隱藏狀態 s t s_t st?,作為查詢(Query)。
  2. 注意力得分計算

    • 對于解碼器狀態 s t s_t st? 和每個編碼器隱藏狀態 h i h_i hi?,計算注意力得分:
      e t , i = w v ? ? tanh ? ( W s s t + W h h i ) e_{t,i} = w_v^\top \cdot \tanh(W_s s_t + W_h h_i) et,i?=wv???tanh(Ws?st?+Wh?hi?)
      其中, W s W_s Ws? W h W_h Wh?是將查詢和鍵映射到隱藏維度的權重矩陣, w v w_v wv?是用于生成標量得分的權重向量。
  3. 注意力權重歸一化

    • 將得分通過softmax函數歸一化為權重:
      $\alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^T \exp(e_{t,j})}
      $
      其中, α t , i \alpha_{t,i} αt,i?表示時間步 t t t 對輸入位置 i i i的關注程度,滿足 ∑ i α t , i = 1 \sum_i \alpha_{t,i} = 1 i?αt,i?=1

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/76619.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/76619.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/76619.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

19【動手學深度學習】卷積層

1. 從全連接到卷積 2. 圖像卷積 3. 圖形卷積代碼 互相關操作 import torch from torch import nn from d2l import torch as d2ldef corr2d(X, K):"""計算2維互相關運算"""h, w K.shapeY torch.zeros((X.shape[0]-h1, X.shape[1]-w 1))for …

Linux xorg-server 解析(一)- 編譯安裝Debug版本的xorg-server

一:下載代碼 1. 配置源,以Ubuntu24.04 為例( /etc/apt/sources.list.d/ubuntu.sources): 2. apt source xserver-xorg-core 二:編譯代碼 1. sudo apt build-dep ./ 2. DEB_BUILD_OPTIONS="nostrip" DEB_CFLAGS_SET="-g -O0" dpkg-buildpac…

大模型SFT用chat版還是base版 SFT后災難性遺忘怎么辦

大模型SFT用chat版還是base版 進行 SFT 時,基座模型選用 Chat 還是 Base 模型? 選 Base 還是 Chat 模型,首先先熟悉 Base 和 Chat 是兩種不同的大模型,它們在訓練數據、應用場景和模型特性上有所區別。 在訓練數據方面&#xf…

【圖像生成之21】融合了Transformer與Diffusion,Meta新作Transfusion實現圖像與語言大一統

論文:Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model 地址:https://arxiv.org/abs/2408.11039 類型:理解與生成 Transfusion模型?是一種將Transformer和Diffusion模型融合的多模態模型,旨…

動態多目標進化算法:基于知識轉移和維護功能的動態多目標進化算法(KTM-DMOEA)求解CEC2018(DF1-DF14)

一、KTM-DMOEA介紹 在實際工程和現實生活中,許多優化問題具有動態性和多目標性,即目標函數會隨著環境的變化而改變,并且存在多個相互沖突的目標。傳統的多目標進化算法在處理這類動態問題時面臨著一些挑戰,如收斂速度慢、難以跟蹤…

部署NFS版StorageClass(存儲類)

部署NFS版StorageClass存儲類 NFS版PV動態供給StorageClass(存儲類)基于NFS實現動態供應下載NFS存儲類資源清單部署NFS服務器為StorageClass(存儲類)創建所需的RBAC部署nfs-client-provisioner的deployment創建StorageClass使用存儲類創建PVC NFS版PV動態供給StorageClass(存儲…

Vue使用el-table給每一行數據上面增加一行自定義合并行

// template <template><el-table:data"flattenedData":span-method"objectSpanMethod"borderclass"custom-header-table"style"width: 100%"ref"myTable":height"60vh"><!-- 訂單詳情列 -->&l…

vue項目使用html2canvas和jspdf將頁面導出成PDF文件

一、需求&#xff1a; 頁面上某一部分內容需要生成pdf并下載 二、技術方案&#xff1a; 使用html2canvas和jsPDF插件 三、js代碼 // 頁面導出為pdf格式 import html2Canvas from "html2canvas"; import jsPDF from "jspdf"; import { uploadImg } f…

大模型LLM表格報表分析:markitdown文件轉markdown,大模型markdown統計分析

整體流程&#xff1a;用markitdown工具文件轉markdown&#xff0c;然后大模型markdown統計分析 markitdown https://github.com/microsoft/markitdown 在線體驗&#xff1a;https://huggingface.co/spaces/AlirezaF138/Markitdown 安裝&#xff1a; pip install markitdown…

Linux 第二講 --- 基礎指令(二)

前言 這是基礎指令的第二部分&#xff0c;但是該部分的講解會大量使用到基礎指令&#xff08;一&#xff09;的內容&#xff0c;為了大家的觀感&#xff0c;如果對Linux的一些基本指令不了解的話&#xff0c;可以先看基礎指令&#xff08;一&#xff09;&#xff0c;同樣的本文…

python格式化字符串漏洞

什么是python格式化字符串漏洞 python中&#xff0c;存在幾種格式化字符串的方式&#xff0c;然而當我們使用的方式不正確的時候&#xff0c;即格式化的字符串能夠被我們控制時&#xff0c;就會導致一些嚴重的問題&#xff0c;比如獲取敏感信息 python常見的格式化字符串 百…

LLaMA-Factory雙卡4090微調DeepSeek-R1-Distill-Qwen-14B醫學領域

unsloth單卡4090微調DeepSeek-R1-Distill-Qwen-14B醫學領域后&#xff0c;跑通一下多卡微調。 1&#xff0c;準備2卡RTX 4090 2&#xff0c;準備數據集 醫學領域 pip install -U huggingface_hub export HF_ENDPOINThttps://hf-mirror.com huggingface-cli download --resum…

React Hooks: useRef,useCallback,useMemo用法詳解

1. useRef&#xff08;保存引用值&#xff09; useRef 通常用于保存“不會參與 UI 渲染&#xff0c;但生命周期要長”的對象引用&#xff0c;比如獲取 DOM、保存定時器 ID、WebSocket等。 新建useRef.js組件&#xff0c;寫入代碼&#xff1a; import React, { useRef, useSt…

Spring AI 結構化輸出詳解

一、Spring AI 結構化輸出的定義與核心概念 Spring AI 提供了一種強大的功能&#xff0c;允許開發者將大型語言模型&#xff08;LLM&#xff09;的輸出從字符串轉換為結構化格式&#xff0c;如 JSON、XML 或 Java 對象。這種結構化輸出能力對于依賴可靠解析輸出值的下游應用程…

THM Billing

1. 信息收集 (1) Nmap 掃描 bashnmap -T4 -sC -sV -p- 10.10.189.216 輸出關鍵信息&#xff1a; PORT STATE SERVICE VERSION22/tcp open ssh OpenSSH 8.4p1 Debian 5deb11u380/tcp open http Apache 2.4.56 (Debian) # MagnusBilling 應用3306/tcp open …

布局決定終局:基于開源AI大模型、AI智能名片與S2B2C商城小程序的戰略反推思維

摘要&#xff1a;在商業競爭日益激烈的當下&#xff0c;布局與終局預判成為企業成功的關鍵要素。本文探討了布局與終局預判的智慧性&#xff0c;強調其雖無法做到百分之百準確&#xff0c;但能顯著提升思考能力。終局思維作為重要戰略工具&#xff0c;并非一步到位的戰略部署&a…

貪心算法 day08(加油站+單調遞增的數字+壞了的計算機)

目錄 1.加油站 2.單調遞增的數字 3.壞了的計算器 1.加油站 鏈接&#xff1a;. - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; gas[index] - cost[index]&#xff0c;ret 表示的是在i位置開始循環時剩余的油量 a到達的最大路徑假設是f那么我們可以得出 a b …

【技術派部署篇】云服務器部署技術派

1 環境搭建 1.1 JDK安裝 # ubuntu sudo apt update # 更新apt apt install openjdk-8-jdk # 安裝JDK安裝完畢之后&#xff0c;執行 java -version 命令進行驗證&#xff1a; 1.2 Maven安裝 cd ~ mkdir soft cd soft wget https://dlcdn.apache.org/maven/maven-3/3.8.8/bina…

Linux:35.其他IPC和IPC原理+信號量入門

通過命名管道隊共享內存的數據發送進行保護的bug&#xff1a; 命名管道掛掉后&#xff0c;進程也掛掉了。 6.systemV消息隊列 原理:進程間IPC:原理->看到同一份資源->維護成為一個隊列。 過程&#xff1a; 進程A,進程B進行通信。 讓操作系統提供一個隊列結構&#xff0c;…

【數據結構】紅黑樹超詳解 ---一篇通關紅黑樹原理(含源碼解析+動態構建紅黑樹)

一.什么是紅黑樹 紅黑樹是一種自平衡的二叉查找樹&#xff0c;是計算機科學中用到的一種數據結構。1972年出現&#xff0c;最初被稱為平衡二叉B樹。1978年更名為“紅黑樹”。是一種特殊的二叉查找樹&#xff0c;紅黑樹的每一個節點上都有存儲表示節點的顏色。每一個節點可以是…