【現代深度學習技術】注意力機制04:Bahdanau注意力

在這里插入圖片描述

【作者主頁】Francek Chen
【專欄介紹】 ? ? ?PyTorch深度學習 ? ? ? 深度學習 (DL, Deep Learning) 特指基于深層神經網絡模型和方法的機器學習。它是在統計機器學習、人工神經網絡等算法模型基礎上,結合當代大數據和大算力的發展而發展出來的。深度學習最重要的技術特征是具有自動提取特征的能力。神經網絡算法、算力和數據是開展深度學習的三要素。深度學習在計算機視覺、自然語言處理、多模態數據分析、科學探索等領域都取得了很多成果。本專欄介紹基于PyTorch的深度學習算法實現。
【GitCode】專欄資源保存在我的GitCode倉庫:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目錄

    • 一、模型
    • 二、定義注意力解碼器
    • 三、訓練
    • 小結


??序列到序列學習(seq2seq)中探討了機器翻譯問題:通過設計一個基于兩個循環神經網絡的編碼器-解碼器架構,用于序列到序列學習。具體來說,循環神經網絡編碼器將長度可變的序列轉換為固定形狀的上下文變量,然后循環神經網絡解碼器根據生成的詞元和上下文變量按詞元生成輸出(目標)序列詞元。然而,即使并非所有輸入(源)詞元都對解碼某個詞元都有用,在每個解碼步驟中仍使用編碼相同的上下文變量。有什么方法能改變上下文變量呢?

??我們試著找到靈感:在為給定文本序列生成手寫的挑戰中,Graves設計了一種可微注意力模型,將文本字符與更長的筆跡對齊,其中對齊方式僅向一個方向移動。受學習對齊想法的啟發,Bahdanau等人提出了一個沒有嚴格單向對齊限制的可微注意力模型。在預測詞元時,如果不是所有輸入詞元都相關,模型將僅對齊(或參與)輸入序列中與當前預測相關的部分。這是通過將上下文變量視為注意力集中的輸出來實現的。

一、模型

??下面描述的Bahdanau注意力模型將遵循序列到序列學習(seq2seq)中的相同符號表達。這個新的基于注意力的模型與序列到序列學習(seq2seq)中的模型相同,只不過其中式(3)中的上下文變量 c \mathbf{c} c在任何解碼時間步 t ′ t' t都會被 c t ′ \mathbf{c}_{t'} ct?替換。假設輸入序列中有 T T T個詞元,解碼時間步 t ′ t' t的上下文變量是注意力集中的輸出:
c t ′ = ∑ t = 1 T α ( s t ′ ? 1 , h t ) h t (1) \mathbf{c}_{t'} = \sum_{t=1}^T \alpha(\mathbf{s}_{t' - 1}, \mathbf{h}_t) \mathbf{h}_t \tag{1} ct?=t=1T?α(st?1?,ht?)ht?(1) 其中,時間步 t ′ ? 1 t' - 1 t?1時的解碼器隱狀態 s t ′ ? 1 \mathbf{s}_{t' - 1} st?1?是查詢,編碼器隱狀態 h t \mathbf{h}_t ht?既是鍵,也是值,注意力權重 α \alpha α是使用加性注意力打分函數計算的。

??與循環神經網絡編碼器-解碼器架構略有不同,圖1描述了Bahdanau注意力的架構。

在這里插入圖片描述

圖1 一個帶有Bahdanau注意力的循環神經網絡編碼器-解碼器模型

import torch
from torch import nn
from d2l import torch as d2l

二、定義注意力解碼器

??下面看看如何定義Bahdanau注意力,實現循環神經網絡編碼器-解碼器。其實,我們只需重新定義解碼器即可。為了更方便地顯示學習的注意力權重,以下AttentionDecoder類定義了帶有注意力機制解碼器的基本接口。

#@save
class AttentionDecoder(d2l.Decoder):"""帶有注意力機制解碼器的基本接口"""def __init__(self, **kwargs):super(AttentionDecoder, self).__init__(**kwargs)@propertydef attention_weights(self):raise NotImplementedError

??接下來,讓我們在接下來的Seq2SeqAttentionDecoder類中實現帶有Bahdanau注意力的循環神經網絡解碼器。首先,初始化解碼器的狀態,需要下面的輸入:

  1. 編碼器在所有時間步的最終層隱狀態,將作為注意力的鍵和值;
  2. 上一時間步的編碼器全層隱狀態,將作為初始化解碼器的隱狀態;
  3. 編碼器有效長度(排除在注意力池中填充詞元)。

??在每個解碼時間步驟中,解碼器上一個時間步的最終層隱狀態將用作查詢。因此,注意力輸出和輸入嵌入都連結為循環神經網絡解碼器的輸入。

class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# outputs的形狀為(batch_size,num_steps,num_hiddens).# hidden_state的形狀為(num_layers,batch_size,num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs的形狀為(batch_size,num_steps,num_hiddens).# hidden_state的形狀為(num_layers,batch_size,# num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# 輸出X的形狀為(num_steps,batch_size,embed_size)X = self.embedding(X).permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:# query的形狀為(batch_size,1,num_hiddens)query = torch.unsqueeze(hidden_state[-1], dim=1)# context的形狀為(batch_size,1,num_hiddens)context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)# 在特征維度上連結x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)# 將x變形為(1,batch_size,embed_size+num_hiddens)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)# 全連接層變換后,outputs的形狀為# (num_steps,batch_size,vocab_size)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights

??接下來,使用包含7個時間步的4個序列輸入的小批量測試Bahdanau注意力解碼器。

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

在這里插入圖片描述

三、訓練

??與序列到序列學習(seq2seq)類似,我們在這里指定超參數,實例化一個帶有Bahdanau注意力的編碼器和解碼器,并對這個模型進行機器翻譯訓練。由于新增的注意力機制,訓練要序列到序列學習(seq2seq)比沒有注意力機制的慢得多。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

在這里插入圖片描述
在這里插入圖片描述

??模型訓練后,我們用它將幾個英語句子翻譯成法語并計算它們的BLEU分數。

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, dec_attention_weight_seq = d2l.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True)print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

在這里插入圖片描述

attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((1, 1, -1, num_steps))

??訓練結束后,下面通過可視化注意力權重會發現,每個查詢都會在鍵值對上分配不同的權重,這說明在每個解碼步中,輸入序列的不同部分被選擇性地聚集在注意力池中。

# 加上一個包含序列結束詞元
d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),xlabel='Key positions', ylabel='Query positions')

在這里插入圖片描述

小結

  • 在預測詞元時,如果不是所有輸入詞元都是相關的,那么具有Bahdanau注意力的循環神經網絡編碼器-解碼器會有選擇地統計輸入序列的不同部分。這是通過將上下文變量視為加性注意力池化的輸出來實現的。
  • 在循環神經網絡編碼器-解碼器中,Bahdanau注意力將上一時間步的解碼器隱狀態視為查詢,在所有時間步的編碼器隱狀態同時視為鍵和值。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/80200.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/80200.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/80200.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

爬蟲學習————開始

🌿自動化的思想 任何領域的發展原因————“不斷追求生產方式的改革,即使得付出與耗費精力越來愈少,而收獲最大化”。由此,創造出方法和設備來提升效率。 如新聞的5W原則直接讓思考過程規范化、流程化。或者前端框架/后端輪子的…

每天五分鐘機器學習:KTT條件

本文重點 在前面的課程中,我們學習了拉格朗日乘數法求解等式約束下函數極值,如果約束不是等式而是不等式呢?此時就需要KTT條件出手了,KTT條件是拉格朗日乘數法的推廣。KTT條件不僅統一了等式約束與不等式約束的優化問題求解范式,KTT條件給出了這類問題取得極值的一階必要…

leetcode0829. 連續整數求和-hard

1 題目: 連續整數求和 官方標定難度:難 給定一個正整數 n,返回 連續正整數滿足所有數字之和為 n 的組數 。 示例 1: 輸入: n 5 輸出: 2 解釋: 5 2 3,共有兩組連續整數([5],[2,3])求和后為 5。 示例 2: 輸入: n 9 輸出: …

window 顯示驅動開發-線性伸縮空間段

線性伸縮空間段類似于線性內存空間段。 但是,伸縮空間段只是地址空間,不能容納位。 若要保存位,必須分配系統內存頁,并且必須重定向地址空間范圍以引用這些頁面。 內核模式顯示微型端口驅動程序(KMD)必須實…

Cadence 高速系統設計流程及工具使用三

5.8 約束規則的應用 5.8.1 層次化約束關系 在應用約束規則之前,我們首先要了解這些約束規則是如何作用在 Cadence 設計對象上的。Cadence 中對設計對象的劃分和概念,如表 5-11 所示。 在 Cadence 系統中,把設計對象按層次進行了劃分&#…

ScaleTransition 是 Flutter 中的一個動畫組件,用于實現縮放動畫效果。

ScaleTransition 是 Flutter 中的一個動畫組件,用于實現縮放動畫效果。它允許你對子組件進行動態的縮放變換,從而實現平滑的動畫效果。ScaleTransition 通常與 AnimationController 和 Tween 一起使用,以控制動畫的開始、結束和過渡效果。 基…

深入解析:如何基于開源p-net快速開發Profinet從站服務

一、Profinet協議與軟協議棧技術解析 1.1 工業通信的"高速公路" Profinet作為工業以太網協議三巨頭之一,采用IEEE 802.3標準實現實時通信,具有: 實時分級:支持RT(實時)和IRT(等時實時)通信模式拓撲靈活:支持星型、樹型、環型等多種網絡結構對象模型:基于…

m個n維向量組中m,n的含義與空間的關系

向量的維度與空間的關系&#xff1a; 一個向量的維度由其分量個數決定&#xff0c;例如 ( n ) 個分量的向量屬于 Rn空間 。 向量組張成空間的維度&#xff1a; 當向量組有 ( m ) 個線性無關的 ( n ) 維向量時&#xff1a; 若 ( m < n )&#xff1a; 這些向量張成的是 Rn中的…

excel大表導入數據庫

前文介紹了數據量較小的excel表導入數據庫的方法&#xff0c;在數據量較大的情況下就不太適合了&#xff0c;一個是因為mysql命令的執行串長度有限制&#xff0c;二是node-xlsx這個模塊加載excel文件是整個文件全部加載到內存&#xff0c;在excel文件較大和可用內存受限的場景就…

Python 爬蟲基礎入門教程(超詳細)

一、什么是爬蟲&#xff1f; 網絡爬蟲&#xff08;Web Crawler&#xff09;&#xff0c;又稱網頁蜘蛛&#xff0c;是一種自動抓取互聯網信息的程序。爬蟲會模擬人的瀏覽行為&#xff0c;向網站發送請求&#xff0c;然后獲取網頁內容并提取有用的數據。 二、Python爬蟲的基本原…

Spring Security 深度解析:打造堅不可摧的用戶認證與授權系統

Spring Security 深度解析&#xff1a;打造堅不可摧的用戶認證與授權系統 一、引言 在當今數字化時代&#xff0c;構建安全可靠的用戶認證與授權系統是軟件開發中的關鍵任務。Spring Security 作為一款功能強大的 Java 安全框架&#xff0c;為開發者提供了全面的解決方案。本…

【物聯網】基于樹莓派的物聯網開發【1】——初識樹莓派

使用背景 物聯網開發從0到1研究&#xff0c;以樹莓派為基礎 場景介紹 系統學習Linux、Python、WEB全棧、各種傳感器和硬件 接下來程序貓將帶領大家進軍物聯網世界&#xff0c;從0開始入門研究樹莓派。 認識樹莓派 正面圖示&#xff1a; 1&#xff1a;樹莓派簡介 樹莓派…

第21節:深度學習基礎-激活函數比較(ReLU, Sigmoid, Tanh)

1. 引言 在深度學習領域,激活函數是神經網絡中至關重要的組成部分 它決定了神經元是否應該被激活以及如何將輸入信號轉換為輸出信號 激活函數為神經網絡引入了非線性因素,使其能夠學習并執行復雜的任務 沒有激活函數,無論神經網絡有多少層,都只能表示線性變換,極大地限…

Fiori學習專題三十:Routing and Navigation

實際上我們的頁面是會有多個的&#xff0c;并且可以在多個頁面之間跳轉&#xff0c;這節課就學習如何在不同頁面之間實現跳轉。 1.修改配置文件manifest.json&#xff0c;加入routing&#xff0c;包含三個部分&#xff0c;config,routes,targets; config &#xff1a; routerC…

【HarmonyOS NEXT+AI】問答05:ArkTS和倉頡編程語言怎么選?

在“HarmonyOS NEXTAI大模型打造智能助手APP(倉頡版)”課程里面&#xff0c;有學員提到了這樣一個問題&#xff1a; 鴻蒙的主推開發語言不是ArkTS嗎&#xff0c;本課程為什么使用的是倉頡編程語言&#xff1f; 這里就這位同學的問題&#xff0c;統一做下回復&#xff0c;以方便…

Booth Encoding vs. Non-Booth Multipliers —— 穿透 DC 架構看乘法器的底層博弈

目錄 &#x1f9ed; 前言 &#x1f331; 1. Non-Booth 乘法器的實現原理&#xff08;也叫常規乘法器&#xff09; &#x1f527; 構建方式 ?? 例子&#xff1a;4x4 Non-Booth 乘法器示意 &#x1f9f1; 硬件結構 ? 特點總結 ? 2. Booth Encoding&#xff08;布斯編碼…

GET請求如何傳復雜數組參數

背景 有個歷史項目&#xff0c;是GET請求&#xff0c;但是很多請求還是復雜參數&#xff0c;比如&#xff1a;參數是數組&#xff0c;且數組中每一個元素都是復雜的對象&#xff0c;這個時候怎么傳參數呢&#xff1f; 看之前請求直接是拼接在url后面 類似&items%5B0%5D.…

iOS App 安全性探索:源碼保護、混淆方案與逆向防護日常

iOS App 安全性探索&#xff1a;源碼保護、混淆方案與逆向防護日常 在 iOS 開發者的日常工作中&#xff0c;我們總是關注功能的完整性、性能的優化和UI的細節&#xff0c;但常常忽視了另一個越來越重要的問題&#xff1a;發布后的應用安全。 尤其是對于中小團隊或獨立開發者&…

A* (AStar) 尋路

//調用工具類獲取路線 let route AStarSearch.getRoute(start_point, end_point, this.mapFloor.map_point); map_point 是所有可走點的集合 import { _decorator, Component, Node, Prefab, instantiate, v3, Vec2 } from cc; import { oops } from "../../../../../e…

深度解析動態IP業務核心場景:從技術演進到行業實踐

引言&#xff1a;動態IP的技術演進與行業價值 在數字化轉型加速的今天&#xff0c;IP地址已從單純的網絡標識演變為支撐數字經濟的核心基礎設施。動態IP作為靈活高效的地址分配方案&#xff0c;正突破傳統認知邊界&#xff0c;在網絡安全防護、數據價值挖掘、全球業務拓展等領…