目錄
一、遺忘門(Forget Gate):決定 “該忘記什么”
二、輸入門(Input Gate):決定 “該記住什么新信息”
三、輸出門(Output Gate):決定 “該輸出什么”
四、候選記憶元(Candidate Cell State):“待存入的新信息”
五、記憶元(Cell State):長期記憶的 “倉庫”
六、 隱狀態(Hidden State):短期輸出與信息傳遞
七、 門控記憶元(Gated Cell):整體協同機制
八、各組件協同流程
九、為什么能解決長期依賴?
十、LSTM的結構圖
?十一、完整代碼
十二、實驗結果
一、遺忘門(Forget Gate):決定 “該忘記什么”
遺忘門的作用是篩選上一時刻記憶元中需要保留的信息。它根據 “上一時刻的隱狀態” 和 “當前輸入”,判斷哪些歷史信息可以被丟棄,哪些需要繼續保留。
- 輸入:上一時刻的隱狀態?
?+ 當前時間步的輸入?
;
- 計算過程: 先將兩者拼接,通過一個全連接層(權重為
,偏置為
),再經過 sigmoid 激活函數(輸出范圍 0~1):
其中,
是 sigmoid 函數(
,輸出?
是一個與記憶元同維度的向量(每個元素對應記憶元中的一個 “信息位”)。
- 含義:
中元素越接近 1:表示上一時刻記憶元中對應位置的信息 “完全保留”;
- 越接近 0:表示對應位置的信息 “完全遺忘”。
例:在句子 “我喜歡吃蘋果,不喜歡吃香蕉,……” 中,當讀到 “香蕉” 時,遺忘門會讓 “蘋果” 的相關信息適當保留(但權重可能降低),以便后續對比。
二、輸入門(Input Gate):決定 “該記住什么新信息”
輸入門的作用是篩選當前輸入中需要存入記憶元的新信息。它和遺忘門配合,共同完成記憶元的更新(先忘舊的,再記新的)。
- 輸入:同樣是上一時刻的隱狀態
?+ 當前輸入?
;
- 計算過程: 拼接后通過另一個全連接層(權重?
,偏置?
),再經 sigmoid 激活:
其中,
也是 0~1 的向量,每個元素對應 “當前輸入中該信息位是否允許進入記憶元”。
- 含義:
中元素越接近 1:表示當前輸入中對應位置的新信息 “允許存入記憶元”;
- 越接近 0:表示該新信息 “不存入記憶元”。
三、輸出門(Output Gate):決定 “該輸出什么”
輸出門的作用是從當前記憶元中篩選信息,生成當前時間步的隱狀態(即模型的 “當前輸出”)。隱狀態會傳遞到下一時間步,同時作為當前步的輸出(比如預測下一個詞)。
- 輸入:依然是?
;
- 計算過程: 拼接后通過第四個全連接層(權重?
,偏置?
),經 sigmoid 激活:
然后,用輸出門?
篩選當前記憶元
中的信息,再經 tanh 處理(確保輸出范圍 - 1~1):
- 含義:
?中元素越接近 1:表示記憶元中對應位置的信息 “允許輸出到隱狀態”;
- 越接近 0:表示該信息 “僅保存在記憶元中,不輸出”。
例:在 “我喜歡吃蘋果,它很甜” 中,當處理 “它” 時,輸出門會從記憶元中篩選 “蘋果” 的信息,讓隱狀態包含 “蘋果”,從而正確預測 “它” 指代 “蘋果”。
四、候選記憶元(Candidate Cell State):“待存入的新信息”
候選記憶元是當前輸入中可能被存入記憶元的 “原始新信息”(未經篩選),相當于 “草稿”,最終是否存入由輸入門決定。
- 輸入:還是?
;
- 計算過程: 拼接后通過第三個全連接層(權重?
,偏置?
),經 tanh 激活(輸出范圍 - 1~1):
- 為什么用 tanh:tanh 將值限制在 - 1~1 之間,避免新信息數值過大導致記憶元 “溢出”,同時保留正負信息(比如 “喜歡” vs “不喜歡”)。
五、記憶元(Cell State):長期記憶的 “倉庫”
記憶元是 LSTM 的 “核心倉庫”,負責存儲長期信息,其狀態會在時間步之間傳遞并被不斷更新。
- 更新規則:結合遺忘門(保留舊信息)和輸入門 + 候選記憶元(添加新信息):
其中,
?是元素級乘法(Hadamard 積)。
- 含義:
:對上一時刻的記憶元?
?按遺忘門的篩選保留部分信息;
:對候選記憶元?
?按輸入門的篩選保留部分新信息;
- 兩者相加:得到當前時刻的記憶元
(舊信息 + 新信息的融合)。
例:在 “我出生在巴黎,……,現在住在倫敦” 中,記憶元會先保留 “巴黎”(遺忘門讓其保留),當讀到 “倫敦” 時,輸入門允許 “倫敦” 進入,記憶元更新為 “巴黎 + 倫敦”(或根據重要性調整權重)。
六、 隱狀態(Hidden State):短期輸出與信息傳遞
隱狀態?是 LSTM 在當前時間步的 “對外輸出”,有兩個作用:
- 作為當前時間步的模型輸出(比如用于序列預測、分類等);
- 傳遞到下一時間步,作為計算下一個時間步各大門和候選記憶元的輸入。
與記憶元的區別:
- 記憶元?
:長期存儲,更新頻率低(主要保留關鍵信息);
- 隱狀態?
:短期輸出,隨時間步快速變化(反映當前時刻的重點信息)。
七、 門控記憶元(Gated Cell):整體協同機制
“門控記憶元” 不是一個獨立組件,而是對 LSTM 中 “記憶元 + 三大門控” 整體機制的統稱。它強調記憶元的更新和輸出是被輸入門、遺忘門、輸出門 “控制” 的,而非像傳統 RNN 那樣無差別傳遞。這種 “門控” 機制正是 LSTM 能處理長期依賴的核心。
八、各組件協同流程
為了更清晰理解,用一個時間步的流程總結:
- 遺忘舊信息:遺忘門
?決定從?
?中保留哪些信息;
- 篩選新信息:輸入門?
決定從候選記憶元?
?中保留哪些新信息;
- 更新記憶元:
(舊信息 + 新信息);
- 生成輸出:輸出門?
從?
中篩選信息,生成隱狀態
。
這個流程在每個時間步重復,使得記憶元能長期保留關鍵信息,隱狀態能靈活輸出當前重點,從而解決長期依賴問題。
- 遺忘門、輸入門、輸出門是 “控制器”,決定信息的刪、存、取;
- 候選記憶元是 “新信息草稿”;
- 記憶元是 “長期倉庫”;
- 隱狀態是 “當前輸出”。
九、為什么能解決長期依賴?
- 遺忘門的靈活性:可以讓不重要的信息快速遺忘(
),而關鍵信息長期保留(
);
- 記憶元的穩定性:記憶元的更新是 “加性” 的(
),而非 RNN 中隱狀態的 “替換性” 更新(
),梯度在反向傳播時更穩定,不易消失;
- 門控的選擇性:輸入門和輸出門可以 “按需” 添加和提取信息,避免無關信息干擾。
十、LSTM的結構圖
?十一、完整代碼
"""
文件名: 9.2 從零實現長短期記憶網絡(LSTM)
作者: 墨塵
日期: 2025/7/16
項目名: dl_env
備注:
"""
# -------------------------- 基礎工具庫導入 --------------------------
import collections # 用于統計詞頻(構建詞表時需統計每個詞元出現的次數)
import random # 隨機抽樣生成訓練數據(增加數據隨機性,提升模型泛化能力)
import re # 文本清洗(通過正則表達式過濾非目標字符)
import requests # 下載數據集(從網絡獲取《時間機器》文本數據)
from pathlib import Path # 文件路徑處理(創建目錄、檢查文件是否存在等)
from d2l import torch as d2l # 深度學習工具庫(提供訓練輔助、可視化等功能)
import math # 數學運算(計算困惑度等指標)
import torch # PyTorch框架(核心深度學習庫,提供張量運算、自動求導等)
from torch import nn # 神經網絡模塊(提供損失函數、層定義等)
from torch.nn import functional as F # 函數式API(提供激活函數、one-hot編碼等工具)# 圖像顯示相關庫(解決中文和符號顯示問題)
import matplotlib.pyplot as plt
import matplotlib.text as text# -------------------------- 核心解決方案:解決文本顯示問題 --------------------------
def replace_minus(s):"""解決Matplotlib中Unicode減號(U+2212)顯示為方塊的問題原理:將特殊減號替換為普通ASCII減號('-'),確保所有環境都能正常顯示"""if isinstance(s, str): # 僅處理字符串類型return s.replace('\u2212', '-') # 替換Unicode減號為ASCII減號return s # 非字符串直接返回# 重寫matplotlib的Text類的set_text方法,實現全局生效
original_set_text = text.Text.set_text # 保存原始方法(避免覆蓋后無法恢復)def new_set_text(self, s):s = replace_minus(s) # 先處理減號return original_set_text(self, s) # 調用原始方法設置文本text.Text.set_text = new_set_text # 應用重寫后的方法(所有文本顯示都會經過此處理)# -------------------------- 字體配置(確保中文和數學符號正常顯示)--------------------------
plt.rcParams["font.family"] = ["SimHei"] # 設置中文字體(SimHei支持中文顯示,避免中文亂碼)
plt.rcParams["text.usetex"] = True # 使用LaTeX渲染文本(提升數學符號顯示美觀度)
plt.rcParams["axes.unicode_minus"] = True # 確保負號正確顯示(避免負號顯示為方塊)
plt.rcParams["mathtext.fontset"] = "cm" # 數學符號使用Computer Modern字體(LaTeX標準字體,更專業)
d2l.plt.rcParams.update(plt.rcParams) # 讓d2l庫的繪圖工具繼承上述配置(保持顯示一致性)# -------------------------- 1. 讀取數據 --------------------------
def read_time_machine():"""下載并讀取《時間機器》數據集,返回清洗后的文本行列表作用:獲取原始文本數據并預處理,為后續詞元化做準備"""data_dir = Path('./data') # 數據存儲目錄(當前目錄下的data文件夾)data_dir.mkdir(exist_ok=True) # 目錄不存在則創建(exist_ok=True避免重復創建報錯)file_path = data_dir / 'timemachine.txt' # 數據集文件路徑# 檢查文件是否存在,不存在則下載if not file_path.exists():print("開始下載時間機器數據集...")# 從d2l官方地址下載文本(《時間機器》是經典數據集,適合語言模型訓練)response = requests.get('http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt')# 寫入文件(utf-8編碼確保兼容多種字符)with open(file_path, 'w', encoding='utf-8') as f:f.write(response.text)print(f"數據集下載完成,保存至: {file_path}")# 讀取文件并清洗文本with open(file_path, 'r', encoding='utf-8') as f:lines = f.readlines() # 按行讀取(每行作為列表元素)print(f"文件讀取成功,總行數: {len(lines)}")if len(lines) > 0:print(f"第一行內容: {lines[0].strip()}") # 打印首行驗證是否正確讀取# 清洗規則:# 1. re.sub('[^A-Za-z]+', ' ', line):保留字母,其他字符(如數字、符號)替換為空格# 2. strip():去除首尾空格# 3. lower():轉小寫(統一大小寫,減少詞元數量)# 4. 過濾空行(if line.strip()確保僅保留非空行)cleaned_lines = [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines if line.strip()]print(f"清洗后有效行數: {len(cleaned_lines)}") # 清洗后非空行數量(去除純空格行)return cleaned_lines# -------------------------- 2. 詞元化與詞表構建 --------------------------
def tokenize(lines, token='char'):"""將文本行轉換為詞元列表(詞元是文本的最小處理單位)參數:lines: 清洗后的文本行列表(如["abc def", "ghi jkl"])token: 詞元類型('char'字符級/'word'單詞級)返回:詞元列表(如字符級:[['a','b','c',' ','d','e','f'], ...])作用:將文本拆分為模型可處理的最小單元(詞元),字符級適合簡單語言模型"""if token == 'char':# 字符級詞元化:將每行拆分為單個字符列表(包括空格,如"abc"→['a','b','c'])return [list(line) for line in lines]elif token == 'word':# 單詞級詞元化:按空格拆分每行(需確保文本已用空格分隔單詞,如"abc def"→['abc','def'])return [line.split() for line in lines]else:raise ValueError('未知詞元類型:' + token)class Vocab:"""詞表類:實現詞元與索引的雙向映射,用于將文本轉換為模型可處理的數字序列核心功能:將字符串形式的詞元轉換為整數索引(模型只能處理數字),同時支持索引轉詞元(用于生成文本)"""def __init__(self, tokens, min_freq=0, reserved_tokens=None):"""構建詞表參數:tokens: 詞元列表(可嵌套,如[[token1, token2], [token3]])min_freq: 最低詞頻閾值(低于此值的詞元不加入詞表,減少詞匯量)reserved_tokens: 預留特殊詞元(如分隔符、填充符等,模型可能需要的特殊標記)"""if reserved_tokens is None:reserved_tokens = [] # 默認為空(無預留詞元)# 統計詞頻:# 1. 展平嵌套列表([token for line in tokens for token in line])# 2. 用Counter計數(得到{詞元: 出現次數}字典)counter = collections.Counter([token for line in tokens for token in line])# 按詞頻降序排序(便于后續按頻率篩選,高頻詞優先保留)self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)# 初始化詞表:# <unk>(未知詞元)固定在索引0(所有未見過的詞元都映射到<unk>)# followed by預留詞元(如用戶指定的特殊標記)self.idx_to_token = ['<unk>'] + reserved_tokens# 構建詞元到索引的映射(字典,便于快速查詢)self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}# 按詞頻添加詞元(過濾低頻詞)for token, freq in self.token_freqs:if freq < min_freq:break # 低頻詞不加入詞表(提前終止,提升效率)if token not in self.token_to_idx: # 避免重復添加預留詞元(如預留詞元已在列表中)self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1 # 索引為當前長度-1(保持連續)def __len__(self):"""返回詞表大小(詞元總數,用于模型輸入/輸出維度設置)"""return len(self.idx_to_token)def __getitem__(self, tokens):"""詞元→索引(支持單個詞元或詞元列表)未知詞元返回<unk>的索引(0),確保模型輸入始終有效"""if not isinstance(tokens, (list, tuple)):# 單個詞元:查字典,默認返回<unk>的索引(0)return self.token_to_idx.get(tokens, self.unk)# 詞元列表:遞歸轉換每個詞元(如['a','b']→[2,3])return [self.__getitem__(token) for token in tokens]def to_tokens(self, indices):"""索引→詞元(支持單個索引或索引列表,用于將模型輸出轉換為文本)"""if not isinstance(indices, (list, tuple)):# 單個索引:直接查列表(如2→'a')return self.idx_to_token[indices]# 索引列表:遞歸轉換每個索引(如[2,3]→['a','b'])return [self.idx_to_token[index] for index in indices]@propertydef unk(self):"""返回<unk>的索引(固定為0,便于統一處理未知詞元)"""return 0# -------------------------- 3. 數據迭代器(隨機抽樣) --------------------------
def seq_data_iter_random(corpus, batch_size, num_steps):"""隨機抽樣生成批量子序列(生成器),用于模型訓練的批量輸入原理:從語料中隨機截取多個長度為num_steps的子序列,組成批次(避免模型學習到固定的句子順序)參數:corpus: 詞元索引序列(1D列表,如[1,3,5,2,...],所有文本的詞元索引拼接而成)batch_size: 批量大小(每個批次包含的子序列數,影響訓練效率和內存占用)num_steps: 子序列長度(時間步,即模型一次處理的序列長度,如35表示一次輸入35個詞元)返回:生成器,每次返回(X, Y):X: 輸入序列(batch_size, num_steps),模型的輸入Y: 標簽序列(batch_size, num_steps),是X右移一位的結果(模型需要預測的下一個詞元)"""# 檢查數據是否足夠生成至少一個子序列(子序列長度+1,因Y是X右移1位,需多1個元素)if len(corpus) < num_steps + 1:raise ValueError(f"語料庫長度({len(corpus)})不足,需至少{num_steps + 1}")# 隨機偏移起始位置(0到num_steps-1),增加數據隨機性(避免每次從固定位置開始)corpus = corpus[random.randint(0, num_steps - 1):]# 計算可生成的子序列總數:# (語料長度-1) // num_steps(-1是因Y需多1個元素,每個子序列需num_steps+1個元素)num_subseqs = (len(corpus) - 1) // num_stepsif num_subseqs < 1:raise ValueError(f"無法生成子序列(語料庫長度不足)")# 生成所有子序列的起始索引(間隔為num_steps,如0, num_steps, 2*num_steps...)initial_indices = list(range(0, num_subseqs * num_steps, num_steps))random.shuffle(initial_indices) # 打亂起始索引,實現隨機抽樣(核心:避免子序列順序固定)# 計算可生成的批次數:子序列總數 // 批量大小(確保每個批次有batch_size個子序列)num_batches = num_subseqs // batch_sizeif num_batches < 1:raise ValueError(f"子序列數量({num_subseqs})不足,需至少{batch_size}個")# 生成批量數據for i in range(0, batch_size * num_batches, batch_size):# 當前批次的起始索引(從打亂的索引中取batch_size個,如i=0時取前batch_size個)indices = initial_indices[i: i + batch_size]# 輸入序列X:每個子序列從indices[j]開始,取num_steps個元素(如indices[j]=0→[0:35])X = [corpus[j: j + num_steps] for j in indices]# 標簽序列Y:每個子序列從indices[j]+1開始,取num_steps個元素(X右移1位,如[1:36])Y = [corpus[j + 1: j + num_steps + 1] for j in indices]# 轉換為張量返回(便于模型處理,PyTorch模型輸入需為張量)yield torch.tensor(X), torch.tensor(Y)# -------------------------- 4. 數據加載函數(關鍵修復:返回可重置的迭代器) --------------------------
def load_data_time_machine(batch_size, num_steps):"""加載《時間機器》數據,返回數據迭代器生成函數和詞表修復點:返回迭代器生成函數(而非一次性迭代器),確保訓練時可重復生成數據(每個epoch重新抽樣)參數:batch_size: 批量大小num_steps: 子序列長度(時間步)返回:data_iter: 迭代器生成函數(調用時返回新的迭代器,每次調用重新抽樣)vocab: 詞表對象(用于詞元與索引的轉換)"""lines = read_time_machine() # 讀取清洗后的文本行tokens = tokenize(lines, token='char') # 字符級詞元化(每個字符為詞元,適合簡單語言模型)vocab = Vocab(tokens) # 構建詞表(根據詞元生成索引映射)# 將所有詞元轉換為索引(展平為1D序列,如[[ 'a', 'b' ], [ 'c' ]]→[2,3,4])corpus = [vocab[token] for line in tokens for token in line]print(f"語料庫長度: {len(corpus)}(詞元索引總數)")# 定義迭代器生成函數:每次調用生成新的隨機抽樣迭代器(確保每個epoch數據不同)def data_iter():return seq_data_iter_random(corpus, batch_size, num_steps)return data_iter, vocab # 返回生成函數和詞表# -------------------------- 5. LSTM模型核心實現 --------------------------def get_lstm_params(vocab_size, num_hiddens, device):"""初始化LSTM的所有參數(權重和偏置)參數:vocab_size: 詞表大小(輸入/輸出維度,因是語言模型,輸入輸出均為詞表詞元)num_hiddens: 隱藏層維度(記憶元/隱狀態的維度,控制模型容量)device: 計算設備(CPU/GPU,參數需存儲在對應設備上)返回:參數列表:包含所有門控、候選記憶元、輸出層的權重和偏置"""num_inputs = num_outputs = vocab_size # 輸入維度=輸出維度=詞表大小def normal(shape):"""生成正態分布的隨機參數(均值0,標準差0.01,避免初始值過大)"""return torch.randn(size=shape, device=device) * 0.01def three():"""生成一組參數(輸入權重、隱藏層權重、偏置),用于門控或候選記憶元"""return (normal((num_inputs, num_hiddens)), # 輸入X的權重(vocab_size × num_hiddens)normal((num_hiddens, num_hiddens)), # 上一時刻隱狀態H的權重(num_hiddens × num_hiddens)torch.zeros(num_hiddens, device=device)) # 偏置(初始為0,num_hiddens維度)# 輸入門參數(W_xi:輸入X到輸入門的權重;W_hi:上一H到輸入門的權重;b_i:偏置)W_xi, W_hi, b_i = three()# 遺忘門參數(W_xf:輸入X到遺忘門的權重;W_hf:上一H到遺忘門的權重;b_f:偏置)W_xf, W_hf, b_f = three()# 輸出門參數(W_xo:輸入X到輸出門的權重;W_ho:上一H到輸出門的權重;b_o:偏置)W_xo, W_ho, b_o = three()# 候選記憶元參數(W_xc:輸入X到候選記憶元的權重;W_hc:上一H到候選記憶元的權重;b_c:偏置)W_xc, W_hc, b_c = three()# 輸出層參數(將隱狀態H映射到輸出詞表維度)W_hq = normal((num_hiddens, num_outputs)) # H到輸出的權重(num_hiddens × vocab_size)b_q = torch.zeros(num_outputs, device=device) # 輸出層偏置# 附加梯度(所有參數需要計算梯度,后續訓練時更新)params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True) # 啟用梯度計算return paramsdef init_lstm_state(batch_size, num_hiddens, device):"""初始化LSTM的初始狀態(記憶元和隱狀態)LSTM有兩個狀態:記憶元(Cell State,長期記憶)和隱狀態(Hidden State,短期輸出)初始狀態均為全0張量"""return (torch.zeros((batch_size, num_hiddens), device=device), # 記憶元c的初始狀態(batch_size × num_hiddens)torch.zeros((batch_size, num_hiddens), device=device)) # 隱狀態h的初始狀態(batch_size × num_hiddens)def lstm(inputs, state, params):"""LSTM前向傳播(核心計算邏輯)參數:inputs: 輸入序列(num_steps, batch_size, vocab_size),每個時間步的輸入(one-hot編碼)state: 初始狀態((H_0, C_0),H_0是初始隱狀態,C_0是初始記憶元)params: LSTM的所有參數(門控、候選記憶元、輸出層的權重和偏置)返回:outputs: 所有時間步的輸出拼接(num_steps*batch_size, vocab_size)(H, C): 最終的隱狀態和記憶元(用于傳遞到下一批次或預測)"""# 解析參數(從params列表中提取各部分參數)[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = state # 初始狀態:H是上一時刻隱狀態,C是上一時刻記憶元outputs = [] # 存儲每個時間步的輸出# 逐時間步計算(inputs的第0維是時間步)for X in inputs:# 1. 計算輸入門(I_t):控制新信息進入記憶元的比例(0~1)# 輸入門由當前輸入X和上一隱狀態H共同決定,sigmoid激活(輸出0~1)I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i) # X×W_xi:輸入X的貢獻;H×W_hi:上一H的貢獻;加偏置后激活# 2. 計算遺忘門(F_t):控制記憶元中舊信息保留的比例(0~1)# 同樣由X和H決定,sigmoid激活F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)# 3. 計算輸出門(O_t):控制記憶元中信息輸出到隱狀態的比例(0~1)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)# 4. 計算候選記憶元(C_tilda):當前時間步的新信息(-1~1)# tanh激活確保值在-1~1之間,避免數值過大C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)# 5. 更新記憶元(C_t):舊信息保留 + 新信息加入# F×C:遺忘門控制舊記憶元C保留的部分;I×C_tilda:輸入門控制新信息加入的部分C = F * C + I * C_tilda# 6. 更新隱狀態(H_t):輸出門控制記憶元中信息的輸出# tanh(C)將記憶元值縮放到-1~1,再由輸出門O篩選H = O * torch.tanh(C)# 7. 計算當前時間步的輸出(Y_t):隱狀態H映射到詞表維度Y = (H @ W_hq) + b_q # H×W_hq:隱狀態到輸出的映射;加偏置outputs.append(Y) # 保存當前時間步的輸出# 拼接所有時間步的輸出(按時間步維度拼接),返回輸出和最終狀態return torch.cat(outputs, dim=0), (H, C)# -------------------------- 6. RNN模型包裝類 --------------------------
class RNNModelScratch: # @save"""從零實現的RNN模型包裝類,統一模型調用接口(適配訓練和預測流程)"""def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):"""參數:vocab_size: 詞表大小(輸入/輸出維度)num_hiddens: 隱藏層維度(記憶元/隱狀態的維度)device: 計算設備get_params: 參數初始化函數(如get_lstm_params)init_state: 狀態初始化函數(如init_lstm_state)forward_fn: 前向傳播函數(如lstm)"""self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device) # 模型參數(通過get_params獲取)self.init_state, self.forward_fn = init_state, forward_fn # 狀態初始化和前向傳播函數def __call__(self, X, state):"""模型調用接口(前向傳播入口,兼容PyTorch的調用方式)參數:X: 輸入序列(batch_size, num_steps),元素為詞元索引(未編碼的原始輸入)state: 初始隱藏狀態((H_0, C_0))返回:y_hat: 輸出(num_steps*batch_size, vocab_size),所有時間步的輸出拼接state: 最終隱藏狀態((H_t, C_t))"""# 處理輸入:# 1. X.T:轉置為(num_steps, batch_size)(便于逐時間步處理,時間步在前)# 2. F.one_hot:轉換為one-hot編碼(num_steps, batch_size, vocab_size),將索引轉為向量# 3. type(torch.float32):轉換為浮點型(適配后續矩陣運算,權重為浮點型)X = F.one_hot(X.T, self.vocab_size).type(torch.float32)# 調用前向傳播函數(如lstm)計算輸出和新狀態return self.forward_fn(X, state, self.params)def begin_state(self, batch_size, device):"""獲取初始隱藏狀態(調用初始化函數,封裝狀態初始化邏輯)"""return self.init_state(batch_size, self.num_hiddens, device)# -------------------------- 7. 預測函數(文本生成) --------------------------
def predict_ch8(prefix, num_preds, net, vocab, device): # @save"""根據前綴生成后續字符(文本生成,驗證模型學習效果)參數:prefix: 前綴字符串(如"time traveller",模型基于此生成后續內容)num_preds: 要生成的字符數net: 訓練好的LSTM模型vocab: 詞表(用于詞元與索引的轉換)device: 計算設備返回:生成的字符串(前綴+預測字符,如前綴"ti"生成"time...")"""# 初始化狀態(批量大小為1,因僅生成一條序列,無需并行)state = net.begin_state(batch_size=1, device=device)# 記錄輸出索引:初始為前綴首字符的索引(將前綴轉換為索引序列)outputs = [vocab[prefix[0]]]# 輔助函數:獲取當前輸入(最后一個輸出的索引,形狀(1,1),符合模型輸入格式)def get_input():return torch.tensor([outputs[-1]], device=device).reshape((1, 1))# 預熱期:用前綴更新模型狀態(不生成新字符,僅讓模型"記住"前綴的信息)for y in prefix[1:]:_, state = net(get_input(), state) # 前向傳播,更新狀態(忽略輸出,因只需狀態)outputs.append(vocab[y]) # 記錄前綴字符的索引(確保outputs包含完整前綴)# 預測期:生成num_preds個字符for _ in range(num_preds):y, state = net(get_input(), state) # 前向傳播,獲取輸出和新狀態(y是當前時間步的輸出)# 取概率最大的字符索引(貪婪采樣:簡單策略,選擇模型認為最可能的下一個字符)outputs.append(int(y.argmax(dim=1).reshape(1)))# 將索引轉換為字符,拼接成字符串返回(完成從索引到文本的轉換)return ''.join([vocab.idx_to_token[i] for i in outputs])# -------------------------- 8. 梯度裁剪(防止梯度爆炸) --------------------------
def grad_clipping(net, theta): # @save"""裁剪梯度(將梯度L2范數限制在theta內),防止梯度爆炸(RNN訓練中常見問題)原理:若梯度范數超過閾值theta,則按比例縮小所有梯度,確保訓練穩定參數:net: 模型(自定義模型或nn.Module)theta: 梯度閾值(如1.0,根據經驗設置)"""# 獲取需要梯度更新的參數if isinstance(net, nn.Module):# 若為PyTorch官方Module,直接取parameters(包含所有需要梯度的參數)params = [p for p in net.parameters() if p.requires_grad]else:# 若為自定義模型(如RNNModelScratch),取params屬性(存儲模型參數)params = net.params# 計算所有參數梯度的L2范數(平方和開根號)norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta: # 若范數超過閾值,按比例裁剪(保持梯度方向不變,縮小幅度)for param in params:param.grad[:] *= theta / norm# -------------------------- 9. 訓練函數 --------------------------
def train_epoch_ch8(net, train_iter_fn, loss, updater, device, use_random_iter):"""訓練一個周期(單輪遍歷數據集)參數:net: LSTM模型train_iter_fn: 迭代器生成函數(調用后返回新迭代器,每個epoch重新生成數據)loss: 損失函數(如CrossEntropyLoss,計算預測與標簽的差距)updater: 優化器(如SGD,用于更新模型參數)device: 計算設備use_random_iter: 是否使用隨機抽樣(影響狀態處理方式:隨機抽樣時狀態獨立,無需傳遞)返回:ppl: 困惑度(perplexity,語言模型性能指標,越低表示模型越好)speed: 訓練速度(詞元/秒,衡量訓練效率)"""state, timer = None, d2l.Timer() # 初始化狀態和計時器(timer用于計算訓練速度)metric = d2l.Accumulator(2) # 累加器:(總損失, 總詞元數),用于計算平均損失batches_processed = 0 # 記錄處理的批次數量(驗證是否有數據被處理)# 關鍵修復:每次訓練都通過函數生成新的迭代器(避免迭代器被提前消費,確保每個epoch數據不同)train_iter = train_iter_fn()# 遍歷批量數據(每個X, Y是一個批次)for X, Y in train_iter:batches_processed += 1# 初始化狀態:# - 首次迭代時需初始化(state為None)# - 隨機抽樣時,每個批次的序列獨立(無上下文關聯),需重新初始化if state is None or use_random_iter:state = net.begin_state(batch_size=X.shape[0], device=device)else:# 非隨機抽樣時,分離狀態(切斷梯度回流到之前的批次,避免梯度計算依賴過長導致爆炸)if isinstance(net, nn.Module) and not isinstance(state, tuple):state.detach_() # 單個狀態直接detach(如GRU只有隱狀態)else:for s in state: # 多個狀態(如LSTM有隱狀態和記憶元)逐個detachs.detach_()# 處理標簽:# Y.T.reshape(-1):轉置后展平為(num_steps*batch_size,)(與輸出y_hat的形狀匹配)# 輸出y_hat的形狀是(num_steps*batch_size, vocab_size),標簽需為1D張量y = Y.T.reshape(-1)# 將輸入和標簽移到目標設備(GPU/CPU,確保與模型參數在同一設備)X, y = X.to(device), y.to(device)# 前向傳播:獲取輸出和新狀態y_hat, state = net(X, state)# 計算損失(mean()是因損失函數可能返回每個樣本的損失,取平均得到批次損失)l = loss(y_hat, y.long()).mean()# 反向傳播與參數更新:if isinstance(updater, torch.optim.Optimizer):# 若為PyTorch優化器(如SGD)updater.zero_grad() # 清零梯度(避免梯度累積)l.backward() # 反向傳播計算梯度grad_clipping(net, 1) # 裁剪梯度(閾值1,防止梯度爆炸)updater.step() # 更新參數else:# 若為自定義優化器(如d2l的sgd函數)l.backward()grad_clipping(net, 1)updater(batch_size=1) # 假設批量大小為1的更新(簡化實現)# 累加總損失和總詞元數(用于計算平均損失)# metric[0] += l * y.numel():總損失=批次損失×詞元數(因l是平均損失)# metric[1] += y.numel():總詞元數=累加每個批次的詞元數量metric.add(l * y.numel(), y.numel())# 檢查是否有批次被處理(避免空迭代導致的錯誤)if batches_processed == 0:print("警告:沒有處理任何訓練批次!")return float('inf'), 0# 計算困惑度(perplexity = exp(平均損失),語言模型專用指標,與交叉熵損失正相關)# 平均損失 = 總損失 / 總詞元數,exp后得到困惑度(完美模型困惑度=1)# 速度 = 總詞元數 / 訓練時間(詞元/秒,衡量訓練效率)return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()def train_ch8(net, train_iter_fn, vocab, lr, num_epochs, device, use_random_iter=False):"""訓練模型(多周期,整合單周期訓練邏輯,輸出訓練過程和結果)參數:net: LSTM模型train_iter_fn: 迭代器生成函數vocab: 詞表lr: 學習率(控制參數更新幅度)num_epochs: 訓練周期數(遍歷數據集的次數,影響模型收斂程度)device: 計算設備use_random_iter: 是否使用隨機抽樣(默認False,即順序抽樣)"""loss = nn.CrossEntropyLoss() # 交叉熵損失(適用于分類任務,此處為詞元預測,多分類問題)# 動畫器:可視化訓練過程(實時繪制困惑度隨周期變化的曲線,直觀觀察模型收斂情況)animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',legend=['train'], xlim=[10, num_epochs])# 初始化優化器:if isinstance(net, nn.Module):# 若為PyTorch Module,使用SGD優化器(隨機梯度下降,適合簡單模型)updater = torch.optim.SGD(net.parameters(), lr)else:# 若為自定義模型,使用d2l的sgd函數(簡化的隨機梯度下降實現)updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)# 定義預測函數:根據前綴"time traveller"生成50個字符(驗證模型學習效果)predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)# 多周期訓練for epoch in range(num_epochs):# 訓練一個周期,返回困惑度和速度ppl, speed = train_epoch_ch8(net, train_iter_fn, loss, updater, device, use_random_iter)# 每10個周期打印一次預測結果(觀察生成文本質量變化,判斷模型是否學到有意義的模式)if (epoch + 1) % 10 == 0:print(f"epoch {epoch + 1} 預測: {predict('time traveller')}")animator.add(epoch + 1, [ppl]) # 記錄困惑度,更新動畫# 訓練結束后輸出最終結果(總結模型性能)print(f'最終困惑度 {ppl:.1f}, 速度 {speed:.1f} 詞元/秒 {device}')print(f"time traveller 預測: {predict('time traveller')}") # 用"time traveller"前綴生成文本print(f"traveller 預測: {predict('traveller')}") # 用"traveller"前綴生成文本# -------------------------- 主程序 --------------------------
if __name__ == '__main__':# 超參數設置(根據經驗和任務調整)batch_size, num_steps = 32, 35 # 批量大小=32(每次處理32個序列),時間步=35(每個序列35個詞元)# 加載數據:獲取迭代器生成函數和詞表train_iter, vocab = load_data_time_machine(batch_size, num_steps)# 模型參數vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu() # 詞表大小、隱藏層維度=256,自動選擇GPU/CPUnum_epochs, lr = 500, 0.12 # 訓練周期=500(充分訓練),學習率=0.12(控制更新幅度)# 初始化LSTM模型(使用自定義的從零實現的模型)model = RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)# 開始訓練(調用訓練函數,啟動多周期訓練)train_ch8(model, train_iter, vocab, lr, num_epochs, device)plt.show(block=True) # 顯示訓練過程的動畫圖(阻塞模式,確保圖不閃退,便于觀察)