文章目錄
- 前言
- 一、數據預處理
- 二、輔助訓練工具函數
- 三、繪圖工具函數
- 四、模型定義
- 五、模型訓練與預測
- 六、實例化模型并訓練
- 訓練結果可視化
- 總結
前言
循環神經網絡(RNN)是深度學習中處理序列數據的重要模型,尤其在自然語言處理和時間序列分析中有著廣泛應用。本篇博客將通過一個基于 PyTorch 的 RNN 實現,結合《The Time Machine》數據集,帶你從零開始理解 RNN 的構建、訓練和預測過程。我們將逐步剖析代碼,展示如何加載數據、定義工具函數、構建模型、繪制訓練過程圖表,并最終訓練一個字符級別的 RNN 模型。代碼中包含了數據預處理、模型定義、梯度裁剪、困惑度計算等關鍵步驟,適合希望深入理解 RNN 的初學者和進階者。
本文基于 PyTorch 實現,所有代碼均來自附件,并輔以詳細注釋和圖表說明。讓我們開始吧!
一、數據預處理
首先,我們需要加載和預處理《The Time Machine》數據集,將其轉化為適合 RNN 輸入的格式。以下是數據預處理的完整代碼:
import random
import re
import torch
from collections import Counterdef read_time_machine():"""將時間機器數據集加載到文本行的列表中"""with open('timemachine.txt', 'r') as f:lines = f.readlines()# 去除非字母字符并將每行轉換為小寫return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]def tokenize(lines, token='word'):"""將文本行拆分為單詞或字符詞元"""if token == 'word':return [line.split() for line in lines]elif token == 'char':return [list(line) for line in lines]else:print(f'錯誤:未知詞元類型:{token}')def count_corpus(tokens):"""統計詞元的頻率"""if not tokens:return Counter()if isinstance(tokens[0], list):flattened_tokens = [token for sublist in tokens for token in sublist]else:flattened_tokens = tokensreturn Counter(flattened_tokens)class Vocab:"""文本詞表類,用于管理詞元及其索引的映射關系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []counter = self._count_corpus(self.tokens)self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1@staticmethoddef _count_corpus(tokens):if not tokens:return Counter()if isinstance(tokens[0], list):tokens = [token for sublist in tokens for token in sublist]return Counter(tokens)def __len__(self):return len(self.idx_to_token)def __getitem__(self, tokens):if not isinstance(tokens, (list, tuple)):return self.token_to_idx.get(tokens, self.unk)return [self[token] for token in tokens]def to_tokens(self, indices):if not isinstance(indices, (list, tuple)):return self.idx_to_token[indices]return [self.idx_to_token[index] for index in indices]@propertydef unk(self):return 0@propertydef token_freqs(self):return self._token_freqsdef load_corpus_time_machine(max_tokens=-1):"""返回時光機器數據集的詞元索引列表和詞表"""lines = read_time_machine()tokens = tokenize(lines, 'char')vocab = Vocab(tokens)corpus = [vocab[token] for line in tokens for token in line]if max_tokens > 0:corpus = corpus