NLP Seq2Seq模型

🍨 本文為[🔗365天深度學習訓練營學習記錄博客🍦 參考文章:365天深度學習訓練營🍖 原作者:[K同學啊 | 接輔導、項目定制]\n🚀 文章來源:[K同學的學習圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)

Seq2Seq模型是一種深度學習模型,用于處理序列到序列的任務,它由兩個主要部分組成:編碼器(Encoder)和解碼器(Decoder)。

  1. 編碼器(Encoder): 編碼器負責將輸入序列(例如源語言句子)編碼成一個固定長度的向量,通常稱為上下文向量或編碼器的隱藏狀態。編碼器可以是循環神經網絡(RNN)、長短期記憶網絡(LSTM)或者變種如門控循環單元(GRU)等。編碼器的目標是捕捉輸入序列中的語義信息,并將其編碼成一個固定維度的向量表示。

  2. 解碼器(Decoder): 解碼器接收編碼器生成的上下文向量,并根據它來生成輸出序列(例如目標語言句子)。解碼器也可以是RNN、LSTM、GRU等。在訓練階段,解碼器一次生成一個詞或一個標記,并且其隱藏狀態從一個時間步傳遞到下一個時間步。解碼器的目標是根據上下文向量生成與輸入序列對應的輸出序列。

在訓練階段,Seq2Seq模型的目標是最大化目標序列的條件概率給定輸入序列。為了實現這一點,通常使用了一種稱為教師強制(Teacher Forcing)的技術,即將目標序列中的真實標記作為解碼器的輸入。但是,在推理階段(即模型用于生成新的序列時),解碼器則根據先前生成的標記生成下一個標記,直到生成一個特殊的終止標記或達到最大長度為止。

下面演示了如何使用PyTorch實現一個簡單的Seq2Seq模型,用于將一個序列翻譯成另一個序列。這里我們將使用一個虛構的數據集來進行簡單的法語到英語翻譯。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset# 定義數據集
class SimpleDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]# 定義Encoder
class Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hidden_dim):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_dim, emb_dim)self.rnn = nn.GRU(emb_dim, hidden_dim)def forward(self, src):embedded = self.embedding(src)outputs, hidden = self.rnn(embedded)return outputs, hidden# 定義Decoder
class Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hidden_dim):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.GRU(emb_dim, hidden_dim)self.fc_out = nn.Linear(hidden_dim, output_dim)def forward(self, input, hidden):input = input.unsqueeze(0)embedded = self.embedding(input)output, hidden = self.rnn(embedded, hidden)prediction = self.fc_out(output.squeeze(0))return prediction, hidden# 定義Seq2Seq模型
class Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, src, trg, teacher_forcing_ratio=0.5):batch_size = trg.shape[1]trg_len = trg.shape[0]trg_vocab_size = self.decoder.fc_out.out_featuresoutputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)encoder_outputs, hidden = self.encoder(src)input = trg[0,:]for t in range(1, trg_len):output, hidden = self.decoder(input, hidden)outputs[t] = outputteacher_force = np.random.rand() < teacher_forcing_ratiotop1 = output.argmax(1) input = trg[t] if teacher_force else top1return outputs# 設置參數
INPUT_DIM = 10
OUTPUT_DIM = 10
ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
HID_DIM = 64
N_LAYERS = 1
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5# 實例化模型
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Seq2Seq(enc, dec, device).to(device)# 打印模型結構
print(model)# 定義訓練函數
def train(model, iterator, optimizer, criterion, clip):model.train()epoch_loss = 0for i, batch in enumerate(iterator):src, trg = batchsrc = src.to(device)trg = trg.to(device)optimizer.zero_grad()output = model(src, trg)output_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)loss = criterion(output, trg)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), clip)optimizer.step()epoch_loss += loss.item()return epoch_loss / len(iterator)# 定義測試函數
def evaluate(model, iterator, criterion):model.eval()epoch_loss = 0with torch.no_grad():for i, batch in enumerate(iterator):src, trg = batchsrc = src.to(device)trg = trg.to(device)output = model(src, trg, 0) # 關閉teacher forcingoutput_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)loss = criterion(output, trg)epoch_loss += loss.item()return epoch_loss / len(iterator)# 示例數據
train_data = [(torch.tensor([1, 2, 3]), torch.tensor([3, 2, 1])),(torch.tensor([4, 5, 6]), torch.tensor([6, 5, 4])),(torch.tensor([7, 8, 9]), torch.tensor([9, 8, 7]))]# 超參數
BATCH_SIZE = 3
N_EPOCHS = 10
LEARNING_RATE = 0.001
CLIP = 1# 數據集與迭代器
train_dataset = SimpleDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)# 定義損失函數與優化器
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()# 訓練模型
for epoch in range(N_EPOCHS):train_loss = train(model, train_loader, optimizer, criterion, CLIP)print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f}')

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/711927.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/711927.shtml
英文地址,請注明出處:http://en.pswp.cn/news/711927.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

深入理解Linux線程(LWP):概念、結構與實現機制(2)

&#x1f3ac;慕斯主頁&#xff1a;修仙—別有洞天 ??今日夜電波&#xff1a;會いたい—Naomile 1:12━━━━━━?&#x1f49f;──────── 4:59 &#x1f504; ?? ? ?? ? &a…

Vue3+vite打包后頁面空白問題

vite.config.js vite.config.js 增加 base: ./ import { fileURLToPath, URL } from node:url import { defineConfig } from vite import vue from vitejs/plugin-vue// https://vitejs.dev/config/ export default defineConfig({base: ./,resolve: {alias: {: fileURLToPath…

解析短視頻美顏SDK:美顏美型技術的深度剖析

美顏并非簡單的濾鏡疊加&#xff0c;而是依托著先進的圖像處理和人工智能技術&#xff0c;才能夠達到如此出色的效果。本文將深入探討短視頻美顏SDK背后的技術原理和實現方法&#xff0c;從而揭示其美顏美型技術的深度剖析。 一、美顏SDK的基本原理 美顏SDK的基本原理是通過對…

maven 包管理平臺-01-maven 入門介紹 + Maven、Gradle、Ant、Ivy、Bazel 和 SBT 的詳細對比表格

拓展閱讀 maven 包管理平臺-01-maven 入門介紹 Maven、Gradle、Ant、Ivy、Bazel 和 SBT 的詳細對比表格 maven 包管理平臺-02-windows 安裝配置 mac 安裝配置 maven 包管理平臺-03-maven project maven 項目的創建入門 maven 包管理平臺-04-maven archetype 項目原型 ma…

docker單機啟動mysql、redis容器命令

將your_path、your_password、your_version替換成自己需要的 mysql docker run -d -p 3306:3306 --name mysql --restartalways \ -v /your_path/my.cnf:/etc/mysql/my.cnf \ -v /your_path/log:/logs \ -v /your_path/mysql:/var/lib/mysql \ -e MYSQL_ROOT_PASSWORDyour_pa…

java 企業培訓管理系統Myeclipse開發mysql數據庫web結構jsp編程計算機網頁項目

一、源碼特點 java 企業培訓管理系統是一套完善的java web信息管理系統&#xff0c;對理解JSP java編程開發語言有幫助&#xff0c;系統具有完整的源代碼和數據庫&#xff0c;系統主要采用B/S模式開發。開發環境為TOMCAT7.0,Myeclipse8.5開發&#xff0c;數據庫為Mysql5.0&…

UCWSC

feature fusion neural network based on a decomposition mechanism (FFDM) 輔助信息 作者未提供代碼

學習大數據,所必需的java基礎(6)

文章目錄 集合Set集合介紹HashSet集合的介紹和使用LinkedHashSet的介紹以及使用哈希值哈希值的計算方式HashSet的存儲去重的過程 Map集合Map的介紹HashMap的介紹以及使用HashMap的兩種遍歷方式方式1&#xff1a;獲取key&#xff0c;然后再根據key獲取value方式2&#xff1a;同時…

【Sql Server】Update中的From語句,以及常見更新操作方式

歡迎來到《小5講堂》&#xff0c;大家好&#xff0c;我是全棧小5。 這是《Sql Server》系列文章&#xff0c;每篇文章將以博主理解的角度展開講解&#xff0c; 特別是針對知識點的概念進行敘說&#xff0c;大部分文章將會對這些概念進行實際例子驗證&#xff0c;以此達到加深對…

Docker技術概論(4):Docker CLI 基本用法解析

Docker技術概論&#xff08;4&#xff09; Docker CLI 基本用法解析 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:http…

Python實現PPT演示文稿中視頻的添加、替換及提取

無論是在教室、會議室還是虛擬會議中&#xff0c;PowerPoint 演示文稿都已成為一種無處不在的工具&#xff0c;用于提供具有影響力的可視化內容。PowerPoint 提供了一系列增強演示的功能&#xff0c;在其中加入視頻的功能可以大大提升整體體驗。視頻可以傳達復雜的概念、演示產…

ArkTS中的路由跳轉和HTTP數據請求

路由跳轉 步驟1&#xff1a;找到箭頭所指的文件&#xff0c;在其中添加已創建的頁面 步驟2&#xff1a;導包 步驟3&#xff1a; HTTP數據請求 步驟1&#xff1a;導包 > import http from ohos.net.http; 步驟2&#xff1a;&#xff08;如果需要在頁面加載前請求&#xf…

TcpServer服務器管理模塊(模塊十)

目錄 類功能 類定義 類實現 編譯測試 server.cc gdb測試斷點 忽略SIGPIPE信號 類功能 類定義 // TcpServer服務器管理模塊(即全部模塊的整合) class TcpServer { private:uint64_t _next_id; // 這是一個自動增長的連接IDint _port;i…

Linux學習-C語言-運算符

目錄 算術運算符&#xff1a; - * /:不能除0 %:不能對浮點數操作 &#xff1a;自增與運算符 i&#xff1a;先用再加 i:先加再用 --&#xff1a;自減運算符 常量&#xff0c;表達式不可以&#xff0c;--&#xff0c;變量可以 賦值運算符 三目運算符 逗號表達式 size…

alpine創建lnmp環境alpine安裝nginx+php5.6+mysql

前言 制作lnmp環境&#xff0c;你可以在alpine基礎鏡像中安裝相關的服務&#xff0c;也可以直接使用Dockerfile創建自己需要的環境鏡像。 注意&#xff1a;提前確認自己的alpine版本&#xff0c;本次創建基于alpine3.6進行創建&#xff0c;官方在一些版本中刪除了php5 1、拉取…

JS正則02——js正則表達式中常用的方法、常見修飾符的使用詳解以及各種方法使用情況示例

JS正則02——js正則表達式中常用的方法、常見修飾符的使用詳解以及各種方法使用情況示例 1. 前言1.1 簡介1.2 js正則特殊字符即使用示例 2. 創建正則表達式的方式2.1 兩種創建正則表達式的方式2.2 關于修飾符 3. 正則表達式中常用的方法3.1 test() 方法——正則表達式對象的方法…

Vue之監測數據的原理(對象)

大家有沒有想過&#xff0c;為什么vue可以監測到數據發生改變&#xff1f;其實底層借助了Object.defineProperty&#xff0c;底層有一個Observer的構造函數 讓我為大家簡單的介紹一下吧&#xff01; 我用對象為大家演示一下 const vm new Vue({el: "#app",data: {ob…

Python列表操作函數

在Python中&#xff0c;列表&#xff08;list&#xff09;是一種可變的數據類型&#xff0c;它包含一系列有序的元素。Python提供了一系列內置的函數和方法來操作列表。以下是一些常用的Python列表操作函數和方法&#xff1a; 列表方法 append(x) 將元素x添加到列表的末尾。 …

文獻速遞:帕金森的疾病分享--多模態機器學習預測帕金森病

文獻速遞&#xff1a;帕金森的疾病分享–多模態機器學習預測帕金森病 Title 題目 Multi-modality machine learning predicting Parkinson’s disease 多模態機器學習預測帕金森病 01 文獻速遞介紹 對于漸進性神經退行性疾病&#xff0c;早期和準確的診斷是有效開發和使…

Linux按鍵輸入實驗-對按鍵驅動進行測試

一. 簡介 前面學習在設備樹文件中創建按鍵的設備節點,并實現對按鍵驅動代碼的編寫,文章地址如下:Linux按鍵輸入實驗-創建按鍵的設備節點-CSDN博客Linux按鍵輸入實驗-按鍵的字符設備驅動代碼框架-CSDN博客Linux按鍵輸入實驗-按鍵的GPIO初始化-CSDN博客 本文對所實現的按鍵驅…