《PyTorch深度學習實踐》第十三講RNN進階

一、

雙向循環神經網絡(Bidirectional Recurrent Neural Network,BiRNN)是一種常見的循環神經網絡結構。與傳統的循環神經網絡只考慮歷史時刻的信息不同,雙向循環神經網絡不僅考慮歷史時刻的信息,還考慮未來時刻的信息。

在雙向循環神經網絡中,輸入序列可以被看作是由兩個部分組成的:正向和反向。在正向部分中,輸入數據從前往后進行處理,而在反向部分中,輸入數據從后往前進行處理。這兩個部分在網絡中分別使用獨立的循環神經網絡進行計算,并將它們的輸出進行拼接。這樣,網絡就可以獲得正向和反向兩個方向的信息,并且能夠同時考慮整個序列的上下文信息。

雙向循環神經網絡的作用是在處理序列數據時,提供更全面、更準確的上下文信息,能夠捕獲序列中前后關系,對于很多序列處理任務(例如自然語言處理、語音識別等)的效果都有很大的提升。在本代碼中,設置了 bidirectional=True,意味著使用雙向 GRU,提取的特征包含了正向和反向的信息。在 GRU 層輸出后,通過 torch.cat() 將正向和反向的最后一個時間步的隱含狀態進行拼接,從而得到一個更全面的特征表示。

二、項目簡介:根據名字中的字符來預測其是哪個國家的人

代碼:

import csv
import time
import matplotlib.pyplot as plt
import numpy as np
import math
import gzip  # 用于讀取壓縮文件
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence# 一些超參數
HIDDEN_SIZE = 100
BATCH_SIZE = 256  # 一次訓練的樣本數,為256個名字
N_LAYER = 2  # RNN的層數
N_EPOCHS = 100
N_CHARS = 128  # ASCII碼一共有128個字符
USE_GPU = False # 不好意思,我沒GPU!# 構造數據集
class NameDataset(Dataset):def __init__(self, is_train_set=True):filename = '../Data/names_train.csv.gz' if is_train_set else '../Data/names_test.csv.gz'with gzip.open(filename, 'rt') as f:  # rt表示以只讀模式打開文件,并將文件內容解析為文本形式reader = csv.reader(f)rows = list(reader)  # rows是一個列表,每個元素是一個名字和國家名組成的列表self.names = [row[0] for row in rows]  # 一個很長的列表,每個元素是一個名字,字符串,長度不一,需要轉化為數字self.len = len(self.names)  # 訓練集:13374  測試集:6700self.countries = [row[1] for row in rows]  # 一個很長的列表,每個元素是一個國家名,字符串,需要編碼成數字# 下面兩行的作用其實就是把國家名編碼成數字,因為后面要用到交叉熵損失函數self.country_list = list(sorted(set(self.countries)))  # 列表,按字母表順序排序,去重后有18個國家名self.country_dict = self.getCountryDict()  # 字典,key是國家名,value是country_list的國家名對應的索引(0-17)self.country_num = len(self.country_list)  # 18# 根據樣本的索引返回姓名和國家名對應的索引,可以理解為(特征,標簽),但這里的特征是姓名,后面還需要轉化為數字,標簽是國家名對應的索引def __getitem__(self, index):return self.names[index], self.country_dict[self.countries[index]]# 返回樣本數量def __len__(self):return self.len# 返回一個key為國家名和value為索引的字典def getCountryDict(self):country_dict = dict()  # 空字典for idx, country_name in enumerate(self.country_list):country_dict[country_name] = idxreturn country_dict# 根據索引(標簽值)返回對應的國家名def idx2country(self, index):return self.country_list[index]# 返回國家名(標簽類別)的個數,18def getCountriesNum(self):return self.country_num# 實例化數據集
trainset = NameDataset(is_train_set=True)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testset = NameDataset(is_train_set=False)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)N_COUNTRY = trainset.getCountriesNum()  # 18個國家名,即18個類別# 設計神經網絡模型
class RNNClassifier(torch.nn.Module):def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):super(RNNClassifier, self).__init__()self.hidden_size = hidden_size  # 隱含層的大小,100self.n_layers = n_layers  # RNN的層數,2self.n_directions = 2 if bidirectional else 1  # 是否使用雙向RNN# 詞嵌入層:input_size是輸入的特征數(即不同詞語的個數),即128;embedding_size是詞嵌入的維度(即將詞語映射到的向量的維度),這里讓它等于了隱含層的大小,即100self.embedding = torch.nn.Embedding(input_size, hidden_size)# GRU層:input_size是輸入的特征數(這里是embedding_size,其大小等于hidden_size),即100;hidden_size是隱含層的大小,即100;n_layers是RNN的層數,2;bidirectional是是否使用雙向RNNself.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=bidirectional)# 全連接層:hidden_size是隱含層的大小,即100;output_size是輸出的特征數(即不同類別的個數),即18self.fc = torch.nn.Linear(hidden_size * self.n_directions, output_size)def _init_hidden(self, batch_size):# 初始化隱含層,形狀為(n_layers * num_directions, batch_size, hidden_size)hidden = torch.zeros(self.n_layers * self.n_directions, batch_size, self.hidden_size)return create_tensor(hidden)def forward(self, input, seq_lengths):# input shape:B X S -> S X Binput = input.t()  # 轉置,變成(seq_len,batch_size)batch_size = input.size(1)  # 256,一次訓練的樣本數,為256個名字,即batch_sizehidden = self._init_hidden(batch_size)# 1、嵌入層處理,input:(seq_len,batch_size) -> embedding:(seq_len,batch_size,embedding_size)embedding = self.embedding(input)# pack them upgru_input = pack_padded_sequence(embedding, seq_lengths)# output:(*, hidden_size * num_directions),*表示輸入的形狀(seq_len,batch_size)# hidden:(num_layers * num_directions, batch, hidden_size)output, hidden = self.gru(gru_input, hidden)if self.n_directions == 2:hidden_cat = torch.cat([hidden[-1], hidden[-2]],dim=1)  # hidden[-1]的形狀是(1,256,100),hidden[-2]的形狀是(1,256,100),拼接后的形狀是(1,256,200)else:hidden_cat = hidden[-1]  # (1,256,100)fc_output = self.fc(hidden_cat)  # 返回的是(1,256,18)return fc_output# 下面該函數屬于數據準備階段的延續部分,因為神經網絡只能處理數字,不能處理字符串,所以還需要把姓名轉換成數字
def make_tensors(names, countries):# 傳入的names是一個列表,每個元素是一個姓名字符串,countries也是一個列表,每個元素是一個整數sequences_and_lengths = [name2list(name) for name innames]  # 返回的是一個列表,每個元素是一個元組,元組的第一個元素是姓名字符串轉換成的數字列表,第二個元素是姓名字符串的長度name_sequences = [sl[0] for sl in sequences_and_lengths]  # 返回的是一個列表,每個元素是姓名字符串轉換成的數字列表seq_lengths = torch.LongTensor([sl[1] for sl in sequences_and_lengths])  # 返回的是一個列表,每個元素是姓名字符串的長度countries = countries.long()  # PyTorch 中,張量的默認數據類型是浮點型 (float),這里轉換成整型,可以避免浮點數比較時的精度誤差,從而提高模型的訓練效果# make tensor of name, (Batch_size,Seq_len) 實現填充0的功能seq_tensor = torch.zeros(len(name_sequences), seq_lengths.max()).long()for idx, (seq, seq_len) in enumerate(zip(name_sequences, seq_lengths)):seq_tensor[idx, :seq_len] = torch.LongTensor(seq)# sort by length to use pack_padded_sequence# perm_idx是排序后的數據在原數據中的索引,seq_tensor是排序后的數據,seq_lengths是排序后的數據的長度,countries是排序后的國家seq_lengths, perm_idx = seq_lengths.sort(dim=0, descending=True)seq_tensor = seq_tensor[perm_idx]countries = countries[perm_idx]return create_tensor(seq_tensor), create_tensor(seq_lengths), create_tensor(countries)# 把名字轉換成ASCII碼,返回ASCII碼值列表和名字的長度
def name2list(name):arr = [ord(c) for c in name]return arr, len(arr)# 是否把數據放到GPU上
def create_tensor(tensor):if USE_GPU:device = torch.device('cuda:0')tensor = tensor.to(device)return tensor# 訓練模型
def trainModel():total_loss = 0for i, (names, countries) in enumerate(trainloader, 1):inputs, seq_lengths, target = make_tensors(names, countries)output = classifier(inputs, seq_lengths)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()if i % 10 == 0:print(f'[{timeSince(start)}] Epoch {epoch} ', end='')  # end=''表示不換行print(f'[{i * len(inputs)}/{len(trainset)}] ', end='')print(f'loss={total_loss / (i * len(inputs))}')  # 打印每個樣本的平均損失return total_loss  # 返回的是所有樣本的損失,我們并沒有用上它# 測試模型
def testModel():correct = 0total = len(testset)print('evaluating trained model ...')with torch.no_grad():for i, (names, countries) in enumerate(testloader, 1):inputs, seq_lengths, target = make_tensors(names, countries)output = classifier(inputs, seq_lengths)pred = output.max(dim=1, keepdim=True)[1]  # 返回每一行中最大值的那個元素的索引,且keepdim=True,表示保持輸出的二維特性correct += pred.eq(target.view_as(pred)).sum().item()  # 計算正確的個數percent = '%.2f' % (100 * correct / total)print(f'Test set: Accuracy {correct}/{total} {percent}%')return correct / total  # 返回的是準確率,0.幾幾的格式,用來畫圖def timeSince(since):now = time.time()s = now - sincem = math.floor(s / 60)  # math.floor()向下取整s -= m * 60return '%dmin %ds' % (m, s)  # 多少分鐘多少秒if __name__ == '__main__':classifier = RNNClassifier(N_CHARS, HIDDEN_SIZE, N_COUNTRY, N_LAYER)if USE_GPU:device = torch.device('cuda:0')classifier.to(device)# 定義損失函數和優化器criterion = torch.nn.CrossEntropyLoss()optimizer = optim.Adam(classifier.parameters(), lr=0.001)start = time.time()print('Training for %d epochs...' % N_EPOCHS)acc_list = []# 在每個epoch中,訓練完一次就測試一次for epoch in range(1, N_EPOCHS + 1):# Train cycletrainModel()acc = testModel()acc_list.append(acc)# 繪制在測試集上的準確率epoch = np.arange(1, len(acc_list) + 1)acc_list = np.array(acc_list)plt.plot(epoch, acc_list)plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.grid()plt.show()

運行結果:

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/712360.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/712360.shtml
英文地址,請注明出處:http://en.pswp.cn/news/712360.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

wireshark過濾和tcpdump抓包指令

Wireshark 過濾器的表達式,用于過濾源 IP 地址為 10.184.148.247 并且目標 TCP 端口為 1883 的數據包。啟用抓包后過濾 ip.addr 10.184.148.247 && tcp.port 1883 主機位10.184.148.19和目標端口為 8080 的操作目標 抓包前過濾 host 10.184.148.19 &…

軟件說明書怎么寫?終于有人一次性說清楚了!

每次寫軟件說明書,你是不是總是毫無頭緒,不知道從何下手?到各網站找資料,不僅格式不規范,甚至可能遺漏關鍵內容!挨一頓批不說,還浪費大把時間。別著急,編寫軟件說明書,關…

PostgreSQL開發與實戰(2)常用命令

作者&#xff1a;太陽 1、連庫相關 #連庫 $ psql -h <hostname or ip> -p <端口> [數據庫名稱] [用戶名稱] #連庫并執行命令 $ psql -h <hostname or ip> -p <端口> -d [數據庫名稱] -U <用戶名> -c "運行一個命令;"備注&#xff1…

從理論到落地,大模型評測體系綜合指南

1956年夏&#xff0c;“人工智能” 這一概念被提出。距今已有近70年的發展歷史。中國科學院將其劃分為六個階段&#xff1a;起步發展期&#xff08;1956年—1960s&#xff09;&#xff0c;反思發展期&#xff08;1960s-1970s&#xff09;,應用發展期&#xff08;1970s-1980s),低…

SpringBoot集成Activiti案例

前言 Activiti項目是一項新的基于Apache許可的開源BPM平臺&#xff0c;從基礎開始構建&#xff0c;旨在提供支持新的BPMN 2.0標準&#xff0c;包括支持對象管理組&#xff08;OMG&#xff09;&#xff0c;面對新技術的機遇&#xff0c;諸如互操作性和云架構&#xff0c;提供技…

3.1log | 62.不同路徑,63. 不同路徑 II,343. 整數拆分,96.不同的二叉搜索樹

62.不同路徑 class Solution { public:int uniquePaths(int m, int n) {vector<vector<int>> dp(m,vector<int>(n,0));for(int i0;i<n;i) dp[0][i]1;for(int i0;i<m;i) dp[i][0]1;for(int i1;i<m;i){for(int j1;j<n;j){dp[i][j]dp[i][j-1]dp[i-…

c++八股文:c++編譯與內存管理

文章目錄 1. c內存管理2. 堆與棧3.變量定義與生命周期4.內存對齊5.內存泄露6.智能指針7.new 和 malloc 有什么區別8.delete和free的區別9.什么野指針&#xff0c;怎么產生的&#xff0c;如何避免野指針10.野指針和指針懸浮的區別11.字符串操作函數參考 1. c內存管理 c在運行程…

LeetCode刷題--- 乘積為正數的最長子數組長度

個人主頁&#xff1a;元清加油_【C】,【C語言】,【數據結構與算法】-CSDN博客 個人專欄 力扣遞歸算法題 http://t.csdnimg.cn/yUl2I 【C】 ??????http://t.csdnimg.cn/6AbpV 數據結構與算法 ???http://t.csdnimg.cn/hKh2l 前言&#xff1a;這個專欄主要講述動…

ScheduledThreadPoolExecutor學習

簡介 ScheduledThreadPoolExecutor 是 Java 中的一個類&#xff0c;它屬于 java.util.concurrent 包。這個類是一個線程池&#xff0c;用于在給定的延遲后運行命令&#xff0c;或者定期地執行命令。它是 ThreadPoolExecutor 的一個子類&#xff0c;專門用于處理需要定時或周期…

解釋索引是什么以及它們是如何提高查詢性能的

索引在數據庫管理系統中是一個重要的數據結構&#xff0c;用于幫助快速檢索數據庫表中的數據。它可以被看作是一個指向表中數據的指針列表&#xff0c;這些指針按照某種特定的順序&#xff08;如字母順序或數字順序&#xff09;排列。索引的工作原理類似于書籍的目錄&#xff1…

Python爬蟲實戰第二例【二】

零.前言&#xff1a; 本文章借鑒&#xff1a;Python爬蟲實戰&#xff08;五&#xff09;&#xff1a;根據關鍵字爬取某度圖片批量下載到本地&#xff08;附上完整源碼&#xff09;_python爬蟲下載圖片-CSDN博客 大佬的文章里面有API的獲取&#xff0c;在這里我就不贅述了。 一…

kitex 入門和基于grpc的使用

&#x1f4d5;作者簡介&#xff1a; 過去日記&#xff0c;致力于Java、GoLang,Rust等多種編程語言&#xff0c;熱愛技術&#xff0c;喜歡游戲的博主。 &#x1f4d7;本文收錄于kitex系列&#xff0c;大家有興趣的可以看一看 &#x1f4d8;相關專欄Rust初階教程、go語言基礎系…

【Web】青少年CTF擂臺挑戰賽 2024 #Round 1 wp

好家伙&#xff0c;比賽結束了還有一道0解web題是吧( 隨緣寫點wp(簡單過頭&#xff0c;看個樂就好) 目錄 EasyMD5 PHP的后門 PHP的XXE Easy_SQLi 雛形系統 EasyMD5 進來是個文件上傳界面 說是只能上傳pdf&#xff0c;那就改Content-Type為application/pdf&#xff0c;改…

11.盛最多水的容器

題目&#xff1a;給定一個長度為 n 的整數數組 height 。有 n 條垂線&#xff0c;第 i 條線的兩個端點是 (i, 0) 和 (i, height[i]) 。 找出其中的兩條線&#xff0c;使得它們與 x 軸共同構成的容器可以容納最多的水。 返回容器可以儲存的最大水量。 解題思路&#xff1a;可以…

判斷閏年(1000-2000)

判斷規則&#xff1a;1.能被4整除&#xff0c;不能被100整除是閏年,2.能被400整除是閏年 #include <stdio.h>int is_leap_year(int n){if((n % 400 0)||((n % 4 0)&&(n % 100 ! 0)))return 1;elsereturn 0; } int main() {int i 0;int count 0;for(i 1000;…

基于PHP的在線英語學習平臺

有需要請加文章底部Q哦 可遠程調試 基于PHP的在線英語學習平臺 一 介紹 此在線英語學習平臺基于原生PHP開發&#xff0c;數據庫mysql。系統角色分為學生&#xff0c;教師和管理員。(附帶參考設計文檔) 技術棧&#xff1a;phpmysqlphpstudyvscode 二 功能 學生 1 注冊/登錄/…

C++/Python簡單練手題

前言 最近需要開始使用python&#xff0c;但是對python了解的并不多&#xff0c;于是先從很早之前剛學C時寫過的一些練手題開始&#xff0c;使用python來實現相同的功能&#xff0c;在溫習python基礎語法的同時&#xff0c;也一起來感受感受python的魅力 99乘法表 c&#xf…

kettle開發-Day43-加密環境下運行作業

前言&#xff1a; 金三銀四&#xff0c;開年第一篇我們來介紹下&#xff0c;怎么在加密情況下運行我們的kettle作業及任務。無疑現在所有企業都認識到加密的重要性&#xff0c;加密后的文件在對外傳輸的時候不能被訪問&#xff0c;訪問時出現一堆亂碼&#xff0c;同時正常的應用…

1分鐘學會Python字符串前后綴與編解碼

1.前綴和后綴 前綴和后綴指的是&#xff1a;字符串是否以指定字符開頭和結尾 2.startswith() 判斷字符串是否以指定字符開頭&#xff0c;若是返回True&#xff0c;若不是返回False str1 "HelloPython"print(str1.startswith("Hello")) # Trueprint…

Navicat Premium 16:打破數據庫界限,實現高效管理mac/win版

Navicat Premium 16是一款功能強大的數據庫管理工具&#xff0c;旨在幫助用戶更輕松地連接、管理和保護各種數據庫。該軟件支持多種數據庫系統&#xff0c;如MySQL、Oracle、SQL Server、PostgreSQL等&#xff0c;并提供了直觀的圖形界面&#xff0c;使用戶能夠輕松地完成各種數…