Python----循環神經網絡(LSTM:長短期記憶網絡)

一、RNN的長期依賴問題

? ? ? ??可以看到序列越長累乘項項數越多,項數越多就可能會讓累乘結果越小,此時對于W 的更新就取決于第一項或者是前幾項,也就是RNN模型會丟失很多較遠時刻的信息而 更關注當前較近的幾個時刻的信息,即沒有很好的長期依賴。 通俗來說就是模型記不住以前的東西。但很多時候我們都希望模型記得更久的信息。

二、LSTM模型結構

????????為了解決RNN的長期依賴問題,研究者對傳統RNN的結構進行了優化,提出了 LSTM。

????????通俗來說,RNN就好比是一個給什么都想要的人, 而LSTM是一個給東西還得挑一挑,挑一些有用的人。 這就導致RNN東西越來越多,多到放不下,然后直接把以前的東西丟掉,而LSTM從 一開始就精挑細選把沒用的丟掉,因此在容量一定的情況下LSTM可以裝入更長時間 的信息,并且這些信息都是相對更有用的。

????????LSTM的這種特性是通過門結構來實現的。‘門’的作用就是控制信息保留或丟棄的程 度。

注意:

????????這里的“門”不是只有開關狀態,即是否全部保留或者丟棄,而是保留或者 丟棄的程度。

2.1、輸入門(input gate)

????????sigmoid函數的輸出范圍是0到1,這是一個很 好的性質,我們可以把它的輸出理解為一個概率值或者是權重,即需要保留的程度, 當輸出為1時為全保留,當輸出為0時為全部不保留或者說全部遺忘(實際上, sigmoid函數不會就輸出0或者1),當輸出置于0和1之間時就是以一定程度保留。?

????????我們可以看到輸入依然是上一時間步的隱藏狀態和當前時間 步的輸入,也就是這個保留的程度是通過上一時間步的隱藏狀態和當前時間步的輸入 學習得到的,也就是說LSTM模型對新輸入進行挑選的過程,而這種挑選又是基于以 前的經驗進行的。 現在我們已經單獨分析完輸入門的兩個分支了,它們結合就很簡單了,之間進行,i_t表示的是保留的程度是一個0到1之間g_t是傳統RNN 的部分表示原始的輸出,那么將他們相乘就很容易理解了,就是選擇一定程度的原始 輸入作為輸出。?

2.2、遺忘門(forget gate)

????????sigmoid的作用就很清晰了,充當的就是‘門’的結構,即程度。在組 件中點擊LSTM下的forget gate 可以看到標紅部分就是遺忘門的結構。依然是輸入上 一時間步的隱藏狀態和當前時間步的輸入,經過sigmoid函數輸出,輸出的就是一個 介于0和1之間表示程度的值 。

說是‘遺忘’但本質上還是保留的程度

2.3、update cell state(細胞更新單元)

????????可以看到這個分支是隨著時間步進行更新的,遺忘門就是控制模型記憶的, 控制保留多少以前的記憶。然后加上 i_t和g_t 相乘的結果,實際上就是加上輸入門的輸 入結果,也就是說將多少當前時間步的信息加入到記憶之中。總的來說, 分支的信 息走向就是:先選擇性保留之前的記憶,再選擇性加入當前的信息,得到新的記憶。?

2.4、輸出門(output gate)

????????通過sigmoid函數控制輸出的程度,然后當前時刻的記憶經過tanh激活,再將兩者相乘得到了 即隱藏狀態的輸出。

import torch
import numpy as np
from torch import nn# 1. 字符輸入
text = "In Beijing Sarah bought a basket of apples In Guangzhou Sarah bought a basket of bananas"# 設置隨機種子,保證實驗的可重復性
torch.manual_seed(1)# 3. 數據集劃分
# input_seq 是輸入序列,去掉了最后一個字符
input_seq = [text[:-1]]
# output_seq 是目標序列,去掉了第一個字符,與 input_seq 錯開一位
output_seq = [text[1:]]
print("input_seq:", input_seq)
print("output_seq:", output_seq)# 4. 數據編碼:one-hot 編碼
# 獲取文本中所有不重復的字符
chars = set(text)
# 將字符排序,保證編碼的一致性
chars = sorted(list(chars))
print("chars:", chars)
# 創建字符到數字的映射字典
char2int = {char: ind for ind, char in enumerate(chars)}
print("char2int:", char2int)
# 創建數字到字符的映射字典
int2char = dict(enumerate(chars))
print("int2char:", int2char)
# 將輸入序列中的字符轉換為數字編碼
input_seq = [[char2int[char] for char in seq] for seq in input_seq]
print("input_seq:", input_seq)
# 將輸出序列中的字符轉換為數字編碼
output_seq = [[char2int[char] for char in seq] for seq in output_seq]
print("output_seq:", output_seq)# one-hot 編碼函數,用于將數字編碼轉換為 one-hot 向量
def one_hot_encode(seq, bs, seq_len, size):# 創建一個形狀為 (batch_size, seq_len, vocab_size) 的零矩陣features = np.zeros((bs, seq_len, size), dtype=np.float32)# 遍歷 batch 中的每個序列for i in range(bs):# 遍歷序列中的每個時間步for u in range(seq_len):# 將對應字符的索引位置設置為 1.0features[i, u, seq[i][u]] = 1.0# 將 numpy 數組轉換為 PyTorch 張量return torch.tensor(features, dtype=torch.float32)# 對輸入序列進行 one-hot 編碼
input_seq = one_hot_encode(input_seq, 1, len(text) - 1, len(chars))
# 將輸出序列轉換為 PyTorch 長整型張量,并調整形狀為 (seq_len * batch_size)
output_seq = torch.tensor(output_seq[0], dtype=torch.long).view(-1)
print("output_seq:", output_seq)# 5. 定義前向模型
class Model(nn.Module):def __init__(self, input_size, hidden_size, out_size):super(Model, self).__init__()self.hidden_size = hidden_size# 定義一個 LSTM 層,輸入維度為 input_size,隱藏層維度為 hidden_size,層數為 1,batch_first=True 表示輸入張量的第一個維度是 batch sizeself.lstm1 = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)# 定義一個全連接層,將 LSTM 的輸出映射到詞匯表大小self.fc1 = nn.Linear(hidden_size, out_size)def forward(self, x):# 通過 LSTM 層得到輸出和隱藏狀態# out 的形狀為 (batch_size, seq_len, hidden_size)# hidden 是一個包含 (h_n, c_n) 的元組,每個的形狀為 (num_layers, batch_size, hidden_size)out, hidden = self.lstm1(x)# 將 LSTM 的輸出調整形狀為 (seq_len * batch_size, hidden_size),以便輸入到全連接層x = out.contiguous().view(-1, self.hidden_size)# 通過全連接層得到最終的輸出x = self.fc1(x)return x, hidden# 實例化模型,輸入大小為詞匯表大小,隱藏層大小為 32,輸出大小為詞匯表大小
model = Model(len(chars), 32, len(chars))# 6. 定義損失函數和優化器
# 使用交叉熵損失函數,常用于多分類問題
cri = nn.CrossEntropyLoss()
# 使用 Adam 優化器,學習率為 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 7. 開始迭代訓練
epochs = 1000
for epoch in range(1, epochs + 1):# 通過模型得到輸出和隱藏狀態output, hidden = model(input_seq)# 計算損失loss = cri(output, output_seq)# 清空梯度optimizer.zero_grad()# 反向傳播計算梯度loss.backward()# 更新模型參數optimizer.step()# 8. 顯示頻率設置if epoch == 1 or epoch % 50 == 0:print(f"Epoch [{epoch}/{epochs}], Loss {loss:.4f}")# 預測接下來的幾個字符
input_text = "I"  # 初始輸入字符
to_be_pre_len = 20  # 預測的長度# 進行預測
for i in range(to_be_pre_len):# 將當前輸入文本轉換為字符列表chars = [char for char in input_text]# 將字符列表轉換為數字編碼的 numpy 數組character = np.array([[char2int[c] for c in chars]])# 對數字編碼進行 one-hot 編碼character = one_hot_encode(character, 1, character.shape[1], len(chars))# 將 numpy 數組轉換為 PyTorch 張量character = torch.tensor(character, dtype=torch.float32)# 將 one-hot 編碼的輸入送入模型進行預測out, hidden = model(character)# 獲取最后一個時間步輸出中概率最大的字符的索引char_index = torch.argmax(out[-1]).item()# 將預測的數字索引轉換為字符,并添加到輸入文本中input_text += int2char[char_index]
# 打印預測結果
print("預測到的:", input_text)

?

三、LSTM“不會”梯度消失和梯度爆炸的原因

3.1、RNN的梯度消失和梯度爆炸

????????梯度消失和梯度爆炸是由于RNN的在時間維度上的權值 進行了共享,導致計算梯度時會進行連乘,連乘會導致梯度消失或者梯度爆炸,但是 需要注意的是:當時間步長的時候,連乘的負面效應才會顯現的更加明顯,即意味 著:近距離(近期記憶)并不會消失,但是遠距離(連乘的多了)才會有梯度消失和 梯度爆炸的問題。也就是說:RNN 所謂梯度消失的真正含義是,梯度被近距離梯度 主導,導致模型難以學到遠距離的依賴關系。這其實和傳統的MLP結構的梯度消失和 梯度爆炸并不同,因為傳統MLP在不同的層中并不會權值共享。

3.2、LSTM為什么“不會”梯度消失和梯度爆炸

LSTM也會梯度消失和梯度爆炸!

對于現在的LSTM有三種情況:

????????1、如果我們把讓遺忘門的輸出趨近于1,例如模型初始化時會把 forget bias 設置成 較大的正數,讓遺忘門飽和),這時候遠距離梯度不消失;

????????2、遺忘門接近 0,但這時模型是故意阻斷梯度流的(例如情感分析任務中有一條樣 本 “A,但是 B”,模型讀到“但是”后選擇把遺忘門設置成 0,遺忘掉內容 A,這是合理 的);

????????3、如果 f 介于 [0, 1] 之間的情況,在這種情況下只能說 LSTM 改善(而非解決)了 梯度消失的狀況。

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/81195.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/81195.shtml
英文地址,請注明出處:http://en.pswp.cn/web/81195.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【跨端框架檢測】使用adb logcat檢測Android APP使用的跨端框架方法總結

目錄 Weex 跨端框架使用了uni-app的情況區分使用了uni-app還是Weex 判斷使用了Xamarin判斷使用了KMM框架判斷使用了 ??Ionic 框架判斷使用了Cordova框架判斷使用了Capacitor 框架使用了React Native框架使用了QT框架使用了Cocos框架使用了Electron 框架使用了flutter 框架使用…

以加減法計算器為例,了解C++命名作用域與函數調用

************* C topic: 命名作用域與函數調用 ************* The concept is fully introducted in the last artical. Please refer to 抽象:C命名作用域與函數調用-CSDN博客 And lets make a calculator to review the basic structure in c. 1、全局函數 A…

AIGC小程序項目

一、文生文功能 (1)前端部分 使用 Pinia 狀態管理庫創建的聊天機器人消息存儲模塊,它實現了文生文(文本生成文本)的核心邏輯。 1.Pinia狀態管理 這個模塊管理兩個主要狀態: messages:存儲所…

Axios中POST、PUT、PATCH用法區別

在 Axios 中,POST、PUT 和 PATCH 是用于發送 HTTP 請求的三種不同方法,它們的核心區別源自 HTTP 協議的設計語義。以下是它們的用法和區別: 1. POST 語義:用于創建新資源。 特點: 非冪等(多次調用可能產生…

[爬蟲知識] Cookie與Session

相關實戰案例:[爬蟲實戰] 爬取小說標題與對應內容 相關爬蟲專欄:JS逆向爬蟲實戰 爬蟲知識點合集 爬蟲實戰案例 一、引入場景 在http協議中,瀏覽器是無狀態(即無記憶)的,對于請求與響應的產生數據&#…

怎樣改變中斷優先級?

在STM32中改變中斷優先級可以通過STM32CubeMX配置和代碼中設置兩種方式來實現。以下以STM32F1系列為例進行說明: 使用STM32CubeMX配置 打開工程:在STM32CubeMX中打開你的工程。進入NVIC配置:在Pinout & Configuration選項卡中,點擊NVIC進入中斷向量控制器配置界面。選…

科學計算中的深度學習模型精解:CNN、U-Net 和 Diffusion Models

關鍵要點 模型概述:卷積神經網絡(CNN)、U-Net 和 Diffusion Models 是深度學習中的核心模型,廣泛應用于科學計算任務,如偏微分方程(PDE)求解、圖像分割和數據生成。科學計算應用:CNN 可用于高效求解 PDEs,U-Net 擅長醫學圖像分割和材料分析,Diffusion Models 在生成合…

解決Docker無法拉取鏡像問題:Windows系統配置鏡像加速全指南

問題背景 在使用 Docker 時,你是否遇到過以下報錯? Unable to find image ‘mysql:latest’ locally docker: Error response from daemon: Get “https://registry-1.docker.io/v2/”: dial tcp 128.242.250.155:443: i/o timeout. 這類問題通常是由于…

Spring AI 使用教程

Spring AI 使用教程(2025年5月24日更新) 一、環境搭建與項目初始化 創建Spring Boot項目 使用IDEA或Spring Initializr創建項目,選擇JDK 17或更高版本(推薦21)。勾選依賴項:Spring Web、Lombok,…

iOS 直播特殊禮物特效實現方案(Swift實現,超詳細!)

特殊禮物特效是提升直播互動體驗的關鍵功能,下面我將詳細介紹如何在iOS應用中實現各種高級禮物特效。 基礎特效類型 1.1 全屏動畫特效 class FullScreenAnimationView: UIView {static func show(with gift: GiftModel, in view: UIView) {let effectView FullS…

分布式事務之Seata

概述 Seata有四種模式 AT模式:無侵入式的分布式事務解決方案,適合不希望對業務進行改造的場景,但由于需要添加全局事務鎖,對影響高并發系統的性能。該模式主要關注多DB訪問的數據一致性,也包括多服務下的多DB數據訪問…

信息收集與搜索引擎

6.1 常見的搜索引擎(一、二) 6.1.1 通用搜索引擎 Google/Bing: 用途:基礎信息收集(域名、子域名、敏感文件)。 高級語法: site:target.com:限定搜索目標域名。 filetype:pdf&am…

【Java項目測試報告】:在線聊天平臺(Online-Chat)

被測試項目已部署:登錄頁面http://123.249.78.82:8080/login.html 一、項目背景 1.1 測試目標 驗證系統功能完整性,確保用戶管理、消息傳輸、好友管理等核心模塊符合需求。 1.2 項目技術棧 后端:Spring Boot/Spring MVC/WebSocket 數據…

RAGFlow與Dify的深度刨析

目錄 一、RAGFlow 框架 二、Dify 框架 三、兩者集成 四、深度對比 1. 核心定位對比 2. 核心功能對比 3. 技術架構對比 4. 部署與成本 5. 適用場景推薦 總結 一、RAGFlow 框架 RAGFlow 是一個專注于深度文檔理解和檢索增強生成(RAG)技術的框架…

CQF預備知識:一、微積分 -- 1.2.2 函數f(x)的類型詳解

文中內容僅限技術學習與代碼實踐參考,市場存在不確定性,技術分析需謹慎驗證,不構成任何投資建議。 📖 數學入門全解 本系列教程為CQF(國際量化金融分析師證書)認證所需的數學預備知識,涵蓋所有需要了解的數學基礎知識…

嵌入式工程師常用軟件

1、 Git Git 是公司常用的版本管理工具,人人都要會。在線的 git 教程可以參考菜鳥教程: https://www.runoob.com/git/git-tutorial.html 電子書教程請在搜索欄搜索: git Git 教程很多,常用的命令如下,這些命令可…

TReport組件指南總結

1. TReport 組件簡介 TReport 是一個用于生成和打印報表的組件,通常用于連接數據集(如 TDataSet)并設計復雜的報表布局。它支持動態數據綁定、多頁報表、分組統計、圖表插入等功能。 2. 安裝與配置 安裝:如果使用的是第三方報表工具(如 Rave Reports),需在 Delphi 中通…

spark任務的提交流程

目錄 spark任務的提交流程1. 資源申請與初始化2. 任務劃分與調度3. 任務執行4. 資源釋放與結果處理附:關鍵組件協作示意圖擴展說明SparkContext介紹 spark任務的提交流程 用戶創建一個 Spark Context;Spark Context 去找 Cluster Manager 申請資源同時說明需要多少 CPU 和內…

【C++】C++異步編程四劍客:future、async、promise和packaged_task詳解

C異步編程四劍客:future、async、promise和packaged_task詳解 1. 引言 1.1 異步編程的重要性 在現代C編程中,異步操作是提高程序性能和響應能力的關鍵技術。它允許程序在等待耗時操作(如I/O、網絡請求或復雜計算)完成時繼續執行…

2021-10-28 C++判斷完全平方數

緣由判斷一個整數是否為完全平方數-編程語言-CSDN問答 整數用平方法小數用5分法逼近。 int 判斷平方數(int n) {//緣由https://ask.csdn.net/questions/7546950?spm1005.2025.3001.5141int a 1;while (a < n / a)if (a*a < n)a;else if (a*a n)return 1;elsereturn 0…