python-pytorch 下批量seq2seq+Bahdanau Attention實現問答1.0.000

python-pytorch 下批量seq2seq+Bahdanau Attention實現簡單問答1.0.000

    • 前言
    • 原理看圖
    • 數據準備
    • 分詞、index2word、word2index、vocab_size
    • 輸入模型的數據構造
    • 注意力模型
    • decoder的編寫
    • 關于損失函數和優化器
    • 在預測時
    • 完整代碼
    • 參考

前言

前面實現了 luong的dot 、general、concat注意力實現簡單問答,這里參考官方文檔,實現了python-pytorch 下批量seq2seq+Bahdanau Attention實現問答

原理看圖

在這里插入圖片描述
這里模型選擇和官方不一樣,官方選擇的是GRU,我更喜歡使用LSTM,解碼器和編碼器都是如此。
意思大致思路是:

  1. 計算encoder的encoder_outputs、encoder_hn、encoder_cn
  2. 使用encoder_outputs、encoder_hn計算新的向量和注意力
  3. 在deconder中,以SOS單字開始,循環句子最大長度,在循環中,使用新的向量和單字SOS做cat計算得到decoder的LSTM輸入數據,將該LSTM存起來,最后做cat計算得到decoder的輸出

數據準備

結果類似還是采用前面的結構和數據

seq_example = [“你認識我嗎”, “你住在哪里”, “你知道我的名字嗎”, “你是誰”, “你會唱歌嗎”, “誰是張學友”]
seq_answer = [“當然認識”, “我住在成都”, “我不知道”, “我是機器人”, “我不會”, “她旁邊那個就是”]

分詞、index2word、word2index、vocab_size

分詞然后做基礎準備,包括數據:index2word、word2index、vocab_size、最長的句子長度seq_length,和一些超參數的設置

輸入模型的數據構造

  1. 長度要統一
  2. 問答的句子以EOS結尾,不足補0,如

tensor([[ 3, 4, 5, 6, 2, 0, 0],
[ 3, 7, 8, 9, 2, 0, 0],
[ 3, 10, 5, 11, 12, 6, 2],
[ 3, 13, 14, 2, 0, 0, 0],
[ 3, 15, 16, 6, 2, 0, 0],
[14, 13, 17, 2, 0, 0, 0]])

注意力模型

可以復用,用官方的即可

# Bahdanau
# query=hidden [layer_num,batch_size,hidden_size] keys=encoder_outputs  [seq_len,batch_size,hidden_size]
class Attention(nn.Module):def __init__(self):super(Attention, self).__init__()self.Wa = nn.Linear(hidden_size, hidden_size)self.Ua = nn.Linear(hidden_size, hidden_size)self.Va = nn.Linear(hidden_size, 1)def forward(self, query, keys):scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys))) #[seq_len,batch_size,1]scores = scores.permute(1,0,2).squeeze(2).unsqueeze(1)#[batch_size,1,seq_len]weights = nn.functional.softmax(scores, dim=-1)#[batch_size,1,seq_len]context = torch.bmm(weights, keys.permute(1,0,2))#[batch_size,1,hidden_size]return context, weights

decoder的編寫

思路是,獲得encoder的輸出和hn后,計算得到向量,然后使用向量和目標的每一個字做cat計算,輸入decoder的模型中,然后得出一個字的預測,循環完了以后,就會得到最大句子長度,最后做cat和softmax計算得到輸出。另外,這里要區分訓練和測試,訓練的時候有target,測試的沒有target數據。

關于損失函數和優化器

NLLLoss+Adam的組合優于CrossEntropyLoss+SGD的組合

在預測時

獲取到模型輸出,size是[batch_size,seq_len,vocab_size]后,對結果做topk計算,會得到每一字在vocab_size的概率,連接起來就是一句話

完整代碼

# def getAQ():
#     ask=[]
#     answer=[]
#     with open("./data/flink.txt","r",encoding="utf-8") as f:
#         lines=f.readlines()
#         for line in lines:
#             ask.append(line.split("----")[0])
#             answer.append(line.split("----")[1].replace("\n",""))
#     return answer,ask# seq_answer,seq_example=getAQ()import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import os
from tqdm import tqdmseq_example = ["你認識我嗎", "你住在哪里", "你知道我的名字嗎", "你是誰", "你會唱歌嗎", "誰是張學友"]
seq_answer = ["當然認識", "我住在成都", "我不知道", "我是機器人", "我不會", "她旁

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/15571.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/15571.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/15571.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【話題】我眼神的IT行業現狀與未來趨勢

目錄 一、挑戰 教學資源的重新分配 教師角色的轉變 學生學習方式的改變 教育評價體系的挑戰 二、機遇 個性化學習 跨學科學習 國際合作與交流 創新教育模式 三、如何培養下一代IT專業人才 更新教育理念 加強基礎設施建設 整合課程資源 加強實踐教學 培養跨學科…

easy-es EsAutoConfiguration RestHighLevelClient 沒有自動注入配置

我用的easy-es.version 是 2.0.0-beta1,是基于springboot2開發的,自動注入配置的目錄掃描的是META-INF/spring.factories文件;而我使用的框架是springboot3,springboot3掃描的是META-INF/spring/org.springframework.boot.autocon…

【算法刷題day57】Leetcode:739. 每日溫度、496.下一個更大元素 I

文章目錄 Leetcode 739. 每日溫度解題思路代碼總結 Leetcode 496.下一個更大元素 I解題思路代碼總結 草稿圖網站 java的Deque Leetcode 739. 每日溫度 題目:739. 每日溫度 解析:代碼隨想錄解析 解題思路 維護一個單調棧,當新元素大于棧頂&a…

【Linux】TCP協議【中】{確認應答機制/超時重傳機制/連接管理機制}

文章目錄 1.確認應答機制2.超時重傳機制:超時不一定是真超時了3.連接管理機制 1.確認應答機制 TCP協議中的確認應答機制是確保數據可靠傳輸的關鍵部分。以下是該機制的主要步驟和特點的詳細解釋: 數據分段與發送: 發送方將要發送的數據分成一…

vue深度選擇器(:deep?)

處于 scoped 樣式中的選擇器如果想要做更“深度”的選擇&#xff0c;也即&#xff1a;影響到子組件&#xff0c;可以使用 :deep() 這個偽類&#xff1a; <style lang"scss" scoped> .evaluation-situation-details :deep .cl-icon-arrow-right {display: none…

【Python 對接QQ的接口(二)】簡單用接口查詢【等級/昵稱/頭像/Q齡/當天在線時長/下一個等級升級需多少天】

文章日期&#xff1a;2024.05.25 使用工具&#xff1a;Python 類型&#xff1a;QQ接口 文章全程已做去敏處理&#xff01;&#xff01;&#xff01; 【需要做的可聯系我】 AES解密處理&#xff08;直接解密即可&#xff09;&#xff08;crypto-js.js 標準算法&#xff09;&…

JS根據所選ID數組在源數據中取出對象

let selectIds [1, 3] // 選中id數組let allData [{ id: 1, name: 123 },{ id: 2, name: 234 },{ id: 3, name: 345 },{ id: 4, name: 456 },] // 源數據let newList [] // 最終數據selectIds.map((i) > {allData.filter((item) > {item.id i && newList.pus…

websocket的壓縮和wireshark如何解碼tls

1. websocket的壓縮 見Compression EXPERIMENTAL那一節。 官方文檔&#xff1a;gorilla/websocket 2. 如何wireshark如何解碼tls 下文中代碼中去掉sudo。正常執行 Mac電腦安裝配置Wireshark 抓包工具&#xff0c;解決Https無法抓包問題_mac winshark抓不到-CSDN博客 如果…

aws sqs基礎概念和隊列參數解析

分布式隊列的組成部分 生產者&#xff0c;向隊列發送消息的組件消費者&#xff0c;接受隊列消息隊列&#xff0c;多個sqs服務器存儲冗余存儲消息 sqs自動刪除超過最大留存時間的消息&#xff08;默認4天&#xff09;&#xff0c;可以通過SetQueueAttributes調整為&#xff08…

【408真題】2009-13

“接”是針對題目進行必要的分析&#xff0c;比較簡略&#xff1b; “化”是對題目中所涉及到的知識點進行詳細解釋&#xff1b; “發”是對此題型的解題套路總結&#xff0c;并結合歷年真題或者典型例題進行運用。 涉及到的知識全部來源于王道各科教材&#xff08;2025版&…

JMH 微基準測試(性能測試)

寫本文主要是簡單記錄一下JMH的使用方式。JMH全名是Java Microbenchmark Harness&#xff0c;主要為在jvm上運行的程序進行基準測試的工具。作為一個開發人員&#xff0c;在重構代碼&#xff0c;或者確認功能的性能時&#xff0c;可以選中這個工具。 本文場景&#xff1a;代碼重…

VBA即用型代碼手冊:刪除Excel中空白行Delete Blank Rows in Excel

我給VBA下的定義&#xff1a;VBA是個人小型自動化處理的有效工具。可以大大提高自己的勞動效率&#xff0c;而且可以提高數據的準確性。我這里專注VBA,將我多年的經驗匯集在VBA系列九套教程中。 作為我的學員要利用我的積木編程思想&#xff0c;積木編程最重要的是積木如何搭建…

IDEA中好用的插件

IDEA中好用的插件 CodeGeeXMybatis Smart Code Help ProAlibaba Java Coding Guidelines?(XenoAmess TPM)?通義靈碼常用操作 TranslationStatistic CodeGeeX 官網地址&#xff1a;https://codegeex.cn/ 使用手冊&#xff1a;https://zhipu-ai.feishu.cn/wiki/CuvxwUDDqiErQU…

Android 自定義圖片進度條

用系統的Progressbar&#xff0c;設置圖片drawable作為進度條會出現圖片長度不好控制&#xff0c;容易被截斷&#xff0c;或者變形的問題。而我有個需求&#xff0c;使用圖片背景&#xff0c;和圖片進度&#xff0c;而且在進度條頭部有個閃光點效果。 如下圖&#xff1a; 找了…

速盾:流量攻擊防護DDOS有哪幾種有效的防御措施?

DDoS&#xff08;分布式拒絕服務&#xff09;攻擊是一種網絡攻擊方式&#xff0c;攻擊者通過向目標服務器發送大量的請求&#xff0c;超出其處理能力&#xff0c;導致服務器無法正常運行&#xff0c;從而使服務中斷或降級。為了保護網絡安全&#xff0c;減少DDoS攻擊對網站和服…

Kafka(十三)監控與告警

目錄 Kafka監控與告警1 解決方案1.2 基礎知識JMX監控指標代理查看KafkaJMX遠程端口 1.3 真實案例Kafka Exporter:PromethusPromethus Alert ManagerGrafana 1.3 實際操作部署監控和告警系統1.2.1 部署Kafka Exporter1.2.2 部署Prometheus1.2.3 部署AlertManger1.2.4 添加告警規…

大疆上云API本地部署與飛機上云

文章目錄 前言一、安裝基礎環境1. EMQX 安裝(版本4.4.0)2. MySql 安裝(版本8.0.26)3. Redis 安裝 二、部署后端&#xff08;JDK必須11及以上&#xff09;三、部署前端四、成為大疆開發者五、飛機注冊上云六、綁定飛機七、無人機狀態查看八、直播流查看 前言 大疆上云API官方文…

HarmonyOS鴻蒙應用開發——ArkTS的“內置組件 + 樣式 + 循環和條件渲染”

一、內置組件是咩&#xff1f; 學過前端的都知道&#xff0c;一個組件就是由多個組件組成的&#xff0c;一個組件也可以是多個小組件組成的&#xff0c;組件就是一些什么導航欄、底部、按鈕......啥的&#xff0c;但是組件分為【自定義組件】跟【內置組件】 【自定義組件】就…

Web開發核心

文章目錄 1.http協議簡介2.http協議特性3.http請求和響應協議4.最簡單的Web程序5.基于flask搭建web?站6.瀏覽器開發者?具&#xff08;重點&#xff09; 1.http協議簡介 HTTP協議是Hyper Text Transfer Protocol(超文本傳輸協議)的縮寫&#xff0c;是用于 萬維網(WWW:Norld W…

【狂神說Java】Redis筆記以及拓展

一、Redis 入門 Redis為什么單線程還這么快&#xff1f; 誤區1&#xff1a;高性能的服務器一定是多線程的&#xff1f; 誤區2&#xff1a;多線程&#xff08;CPU上下文會切換&#xff01;&#xff09;一定比單線程效率高&#xff01; 核心&#xff1a;Redis是將所有的數據放在內…