【論文復現】LSTM長短記憶網絡

LSTM

  • 前言
  • 網絡架構
    • 總線
    • 遺忘門
    • 記憶門
    • 記憶細胞
    • 輸出門
  • 模型定義
    • 單個LSTM神經元的定義
    • LSTM層內結構的定義
  • 模型訓練
  • 模型評估
  • 代碼細節
    • LSTM層單元的首尾的處理
    • 配置Tensorflow的GPU版本

前言

LSTM作為經典模型,可以用來做語言模型,實現類似于語言模型的功能,同時還經常用于做時間序列。由于LSTM的原版論文相關版權問題,這里以colah大佬的博客為基礎進行講解。之前寫過一篇Tensorflow中的LSTM詳解,但是原理部分跟代碼部分的聯系并不緊密,實踐性較強但是如果想要進行更加深入的調試就會出現原理性上面的問題,因此特此作文解決這個問題,想要用LSTM這個有趣的模型做出更加好的機器學習效果😊。

網絡架構

LSTM框架圖
這張圖展示了LSTM在整體結構,下面就開始分部分介紹中間這個東東。

總線

在這里插入圖片描述
這條是總線,可以實現神經元結構的保存或者更改,如果就是像上圖一樣一條總線貫穿不做任何改變,那么就是不改變細胞狀態。那么如果想要改變細胞狀態怎么辦?可以通過來實現,這里的門跟高中生物中學的神經興奮閾值比較像,用數學來表示就是sigmoid函數或者其他的激活函數,當門的輸入達到要求,門就會打開,允許當前門后面的信息“穿過”門改變主線上面傳遞的信息,如果把每一個神經元看成一個時間節點,那么從上一個時間節點傳到下一個時間節點過程中的門的開啟與關閉就實現了時間序列數據的信息傳遞。
在這里插入圖片描述

遺忘門

在這里插入圖片描述
首先是遺忘門,這個門的作用是決定從上一個神經元傳輸到當前神經元的數據丟棄的程度,如果經過sigmoid函數以后輸出0表示全部丟棄,輸出1表示全部保留,這個層的輸入是舊的信息和當前的新信息。

σ \sigma σ:sigmoid函數
W f W_f Wf?:權重向量
b f b_f bf?:偏置項,決定丟棄上一個時間節點的程度,如果是正數,表示更容易遺忘,如果是負數,表示比較容易記憶
h t ? 1 h_{t-1} ht?1?:上一個時刻的輸入
x t x_t xt?:當前層的輸入

記憶門

在這里插入圖片描述
接下來是記憶門,這個門決定要記住什么信息,同時決定按照什么程度記住上一個狀態的信息。

i t i_t it?:在時間步t時刻的輸入門激活值,計算方法跟上面的遺忘門是一樣的,只是目的不一樣,這里是記憶
C ~ t \tilde{C}_{t} C~t?:表示上一個時刻的信息和當前時刻的信息的集合,但是是規則化到[-1,1]這個范圍內了的

記憶細胞

在這里插入圖片描述
有了上面的要記憶的信息和要丟棄的信息,記憶細胞的功能就可以得到實現,用 f t f_t ft?這個標量決定上一個狀態要遺忘什么,用 i t i_t it?這個標量決定上一個狀態要記住什么以及當前狀態的信息要記住什么。這樣就形成了一個記憶閉環了。

輸出門

在這里插入圖片描述
最后,在有了記憶細胞以后不僅僅不要將當前細胞狀態記住,還要將當前的信息向下一層繼續傳輸,實現公式中的狀態轉移。

o t o_t ot?:跟前面的門公式都一樣,但是功能是決定輸出的程度
h t h_t ht?:將輸出規范到[-1,1]的區間,這里有兩個輸出的原因是在構建LSTM網絡的時候需要有縱向向上的那個 h t h_t ht?,然而在當前層的LSTM的神經元之間還是首尾相接的😍。

模型定義

單個LSTM神經元的定義


# 定義單個LSTM單元
# 定義單個LSTM單元
class My_LSTM(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(My_LSTM, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_size# 初始化門的權重和偏置,由于每一個神經元都有自己的偏置,所以在定義單元內部定義self.Wf = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bf = nn.Parameter(torch.Tensor(hidden_size))self.Wi = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bi = nn.Parameter(torch.Tensor(hidden_size))self.Wo = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bo = nn.Parameter(torch.Tensor(hidden_size))self.Wg = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bg = nn.Parameter(torch.Tensor(hidden_size))# 初始化輸出層的權重和偏置self.W = nn.Parameter(torch.Tensor(hidden_size, output_size))self.b = nn.Parameter(torch.Tensor(output_size))# 用于計算每一種權重的函數def cal_weight(self, input, weight, bias):return F.linear(input, weight, bias)# x是輸入的數據,數據的格式是(batch, seq_len, input_size),包含的是batch個序列,每個序列有seq_len個時間步,每個時間步有input_size個特征def forward(self, x):# 初始化隱藏層和細胞狀態h = torch.zeros(1, 1, self.hidden_size).to(x.device)c = torch.zeros(1, 1, self.hidden_size).to(x.device)# 遍歷每一個時間步for i in range(x.size(1)):input = x[:, i, :].view(1, 1, -1) # 取出每一個時間步的數據# 計算每一個門的權重f = torch.sigmoid(self.cal_weight(input, self.Wf, self.bf)) # 遺忘門i = torch.sigmoid(self.cal_weight(input, self.Wi, self.bi)) # 輸入門o = torch.sigmoid(self.cal_weight(input, self.Wo, self.bo)) # 輸出門C_ = torch.tanh(self.cal_weight(input, self.Wg, self.bg)) # 候選值# 更新細胞狀態c = f * c + i * C_# 更新隱藏層h = o * torch.tanh(c) # 將輸出標準化到-1到1之間output = self.cal_weight(h, self.W, self.b) # 計算輸出return output

LSTM層內結構的定義

class My_LSTMNetwork(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(My_LSTMNetwork, self).__init__()self.hidden_size = hidden_sizeself.lstm = My_LSTM(input_size, hidden_size)  # 使用自定義的LSTM單元self.fc = nn.Linear(hidden_size, output_size)  # 定義全連接層def forward(self, x):h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)out, _ = self.lstm(x, (h0, c0))  # LSTM層的前向傳播out = self.fc(out[:, -1, :])  # 全連接層的前向傳播return out

模型訓練

history = model.fit(trainX, trainY, batch_size=64, epochs=50, validation_split=0.1, verbose=2)
print('compilation time:', time.time()-start)

模型評估

為了更加直觀展示,這里用畫圖的方法進行結果展示。

fig3 = plt.figure(figsize=(20, 15))
plt.plot(np.arange(train_size+1, len(dataset)+1, 1), scaler.inverse_transform(dataset)[train_size:], label='dataset')
plt.plot(testPredictPlot, 'g', label='test')
plt.ylabel('price')
plt.xlabel('date')
plt.legend()
plt.show()

代碼細節

LSTM層單元的首尾的處理

  • 首部:由于第一個節點不用接受來自上一個節點的輸入,不需要有輸入,當然也有一些是添加標識。

  • 尾部:由于已經進行到當前層的最后一個節點,因此輸出只需要向下一層進行傳遞而不用向下一個節點傳遞,添加標識也是可以的。

配置Tensorflow的GPU版本

這一篇寫的比較好,我自己的硬件環境如下圖所示,需要的可以借鑒一下,當然也可以在我提供的代碼鏈接直接用我給的environment.yml一鍵構建環境😃。
在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/16114.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/16114.shtml
英文地址,請注明出處:http://en.pswp.cn/web/16114.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vue3的proxy如何取代object和defineproperty

在 Vue 2.x 中,為了響應式地追蹤對象屬性的變化,Vue 使用了 Object.defineProperty 方法。但是,Object.defineProperty 有一些限制,比如它不能追蹤屬性的添加或刪除,也不能直接用于數組或對象原型鏈上的屬性。 Vue 3.…

【Torch學習筆記】

作者:zjk 和 的區別是逐元素相乘,是矩陣相乘 cat stack 的區別 cat stack 是用于沿新維度將多個張量堆疊在一起的函數。它要求所有輸入張量具有相同的形狀,并在指定的新維度上進行堆疊。

【NumPy】關于numpy.mean()函數,看這一篇文章就夠了

🧑 博主簡介:阿里巴巴嵌入式技術專家,深耕嵌入式人工智能領域,具備多年的嵌入式硬件產品研發管理經驗。 📒 博客介紹:分享嵌入式開發領域的相關知識、經驗、思考和感悟,歡迎關注。提供嵌入式方向…

Android11熱點啟動和關閉

Android官方關于Wi-Fi Hotspot (Soft AP) 的文章:https://source.android.com/docs/core/connect/wifi-softap?hlzh-cn 在 Android 11 的WifiManager類中有一套系統 API 可以控制熱點的開和關,代碼如下: 開啟熱點: // SoftApC…

Vue 父組件使用refs來直接訪問和修改子組件的屬性或調用子組件的方法

步驟 1: 在子組件中定義要被修改的屬性或方法 首先,在子組件中定義你想要父組件能夠修改或調用的屬性或方法。例如,我們有一個名為MyChildComponent的子組件,它有一個名為childData的數據屬性和一個名為updateData的方法。 // 子組件 MyChi…

國際版Tiktok抖音運營流量實戰班:賬號定位/作品發布/熱門推送/等等-13節

課程目錄 1-tiktok賬號定位 1.mp4 2-tiktok作品發布技巧 1.mp4 3-tiktok數據功能如何開通 1.mp4 4-tiktok熱門視頻推送機制 1.mp4 5-如何發現熱門視頻 1.mp4 6-如何發現熱門音樂 1.mp4 7-如何尋找熱門標簽 1.mp4 8-如何尋找垂直熱門視頻 1.mp4 9-如何發現熱門挑戰賽 1…

【Python特征工程系列】一文教你使用PCA進行特征分析與降維(案例+源碼)

這是我的第287篇原創文章。 一、引言 主成分分析(Principal Component Analysis, PCA)是一種常用的降維技術,它通過線性變換將原始特征轉換為一組線性不相關的新特征,稱為主成分,以便更好地表達數據的方差。 在特征重要…

DAMA數據管理知識體系必背18張框圖

近期對數據管理知識體系中比較重要的框圖進行了梳理總結,總共有18張框圖,供大家參考。主要涉及數據管理、數據治理階段模式、數據安全需求、主數據管理關鍵步驟,主數據架構、DW架構、數據科學的7個階段、數據倉庫建設活動、信息收斂三角、大數據分析架構圖、數據管理成熟度等…

QGIS開發筆記(二):Windows安裝版二次開發環境搭建(上):安裝OSGeo4W運行依賴其Qt的基礎環境Demo

若該文為原創文章,轉載請注明原文出處 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/139136356 長沙紅胖子Qt(長沙創微智科)博文大全:開發技術集合(包含Qt實用技術、樹莓派、三維、OpenCV…

如果返回的json 中有 ‘///’ 轉換

// 將返回數據的三條/和替換空 rowData.Jsonobj rowData.Jsonobj .replace(/^\s*\/\/\/.*$/gm, //); // 將返回的替換成" 并且外面加個"" rowData.Jsonobj "${rowData.Jsonobj .replace(//g, ")}"; // 轉換回來數據用兩個 JSON.parse(JSON.par…

Charles抓包App_https_夜神模擬器

Openssl安裝 下載安裝 下載地址: http://slproweb.com/products/Win32OpenSSL.html 我已經下載好了64位的,也放出來: 鏈接:https://pan.baidu.com/s/1Nkur475YK48_Ayq_vEm99w?pwdf4d7 提取碼:f4d7 --來自百度網…

地下城游戲(leetcode)

個人主頁&#xff1a;Lei寶啊 愿所有美好如期而遇 地下城游戲https://leetcode.cn/problems/dungeon-game/description/ 圖解分析&#xff1a; 代碼 class Solution { public:int calculateMinimumHP(vector<vector<int>>& vv) {int row vv.size(), col …

Zookeeper 安裝教程和使用指南

一、Zookeeper介紹 ZooKeeper 是 Apache 軟件基金會的一個開源項目&#xff0c;主要基于 Java 語言實現。 Apache ZooKeeper 是一個開源的分布式應用程序協調服務&#xff0c;提供可靠的數據管理通知、數據同步、命名服務、分布式配置服務、分布式協調等服務。 關鍵特性 分布…

Nginx實戰(安裝部署、常用命令、反向代理、負載均衡、動靜分離)

文章目錄 1. nginx安裝部署1.1 windows安裝包1.2 linux-源碼編譯1.3 linux-docker安裝 2. nginx介紹2.1 簡介2.2 常用命令2.3 nginx運行原理2.3.1 mater和worker2.3.3 Nginx 的工作原理 2.4 nginx的基本配置文件2.4.1 location指令說明 3. nginx案例3.1 nginx-反向代理案例013.…

數據結構和算法|排序算法系列(三)|插入排序(三路排序函數std::sort)

首先需要你對排序算法的評價維度和一個理想排序算法應該是什么樣的有一個基本的認知&#xff1a; 《Hello算法之排序算法》 主要內容來自&#xff1a;Hello算法11.4 插入排序 插入排序的整個過程與手動整理一副牌非常相似。 我們在未排序區間選擇一個基準元素&#xff0c;將…

移動云以深度融合之服務,令“大”智慧貫穿云端

移動云助力大模型&#xff0c;開拓創新領未來。 云計算——AI模型的推動器。 當前人工智能技術發展的現狀和趨勢&#xff0c;以及中國在人工智能領域的發展策略和成就。確實&#xff0c;以 ChatGPT 為代表的大型語言模型在自然語言處理、文本生成、對話系統等領域取得了顯著的…

項目管理:敏捷實踐框架

一、初識敏捷 什么是敏捷(Agile)?敏捷是思維方式。 傳統開發模型 央企,國企50%-60%需求分析。整體是由文檔控制的過程管理。 傳統軟件開發面臨的問題: 交付周期長:3-6個月甚至更長溝通效果差:文檔化溝通不及時按時發布低:技術債增多無法發版團隊士氣弱:死亡行軍不關注…

Vmware 17安裝 CentOS9

前言 1、提前下載好需要的CentOS9鏡像&#xff0c;下載地址&#xff0c;這里下載的是x86_64 2、提前安裝好vmware 17&#xff0c;下載地址 &#xff0c;需要登錄才能下載 安裝 1、創建新的虛擬機 2、在彈出的界面中選擇對應的類型&#xff0c;我這里選擇自定義&#xff0c;點…

python command亂碼怎么解決

python command亂碼怎么解決&#xff1f;具體方法如下&#xff1a; 先引入import sys 再加一句&#xff1a;typesys.getfilesystemencoding() 然后在輸出亂碼的數據的后面加上“.decode(utf-8).encode(type)”。 比如輸入“ss”亂碼。 就寫成print ss.decode(utf-8).encode(typ…

USB - Host controller類型介紹

USB 主機控制器類型 USB 主機控制器是計算機系統中的重要組件&#xff0c;負責管理計算機與連接的 USB 設備之間的通信。多年來&#xff0c;針對不同的 USB 標準和數據傳輸速率&#xff0c;開發了多種類型的 USB 主機控制器。以下是主要 USB 主機控制器類型的概述&#xff1a; …