Table of Contents
- Preface
- 1. Data Loading and Preprocessing
- 1.1 Code Implementation
- 1.2 Functionality Walkthrough
- 2. Introduction to LSTM
- 2.1 How LSTM Works
- 2.2 Model Definition
- Code Walkthrough
- 3. Training and Prediction
- 3.1 Training Logic
- Code Walkthrough
- 3.2 Visualization Utilities
- Functionality Walkthrough
- Results
- Summary
Preface
Recurrent neural networks (RNNs) and their variant, the long short-term memory network (LSTM), excel at handling sequential data such as text and time series. In this post we work through a complete PyTorch implementation that teaches you, from scratch, how to use an LSTM for text generation. Using H.G. Wells's The Time Machine as the dataset, we walk step by step through data preprocessing, model definition, training, and prediction. By combining code with commentary, the goal is to give you a solid understanding of the implementation details of LSTMs and how they are applied in natural language processing.
The code in this post is organized into four main parts:
- Data loading and preprocessing (utils_for_data.py)
- LSTM model definition (the model section of the Jupyter notebook)
- Training and prediction logic (utils_for_train.py)
- Visualization utilities (utils_for_huitu.py)
The detailed implementation and walkthrough follow.
1. Data Loading and Preprocessing
First, we need to load The Time Machine dataset and preprocess it. The complete code from utils_for_data.py, together with a walkthrough of its functionality, appears in the subsections below.
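Before the full listing, here is a minimal usage sketch of how the utilities defined in this file fit together. It is only an illustration of the intended call pattern, assuming timemachine.txt sits in the working directory; every function it references (load_corpus_time_machine, Vocab.to_tokens, seq_data_iter_random) is defined in the listing in Section 1.1, and the exact printed values depend on your copy of the text.

# Minimal usage sketch (assumes the functions from utils_for_data.py are in scope
# and timemachine.txt is in the working directory).
corpus, vocab = load_corpus_time_machine(max_tokens=10000)
print(len(corpus), len(vocab))        # number of character indices, vocabulary size
print(vocab.to_tokens(corpus[:10]))   # map the first ten indices back to characters

# Iterate over randomly sampled minibatches; Y is X shifted right by one step.
for X, Y in seq_data_iter_random(corpus, batch_size=2, num_steps=5):
    print(X.shape, Y.shape)           # both are (batch_size, num_steps)
    break

With a character-level vocabulary, the corpus is simply a long list of character indices, and each minibatch pairs an input window X with a target window Y shifted one character to the right.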
1.1 Code Implementation
import random
import re
import torch
from collections import Counter


def read_time_machine():
    """Load the Time Machine dataset into a list of text lines."""
    with open('timemachine.txt', 'r') as f:
        lines = f.readlines()
    # Keep only letters, collapse everything else into spaces, and lowercase.
    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]


def tokenize(lines, token='word'):
    """Split text lines into word or character tokens."""
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        print(f'Error: unknown token type: {token}')


def count_corpus(tokens):
    """Count token frequencies."""
    if not tokens:
        return Counter()
    if isinstance(tokens[0], list):
        # Flatten a list of token lists into a single list of tokens.
        flattened_tokens = [token for sublist in tokens for token in sublist]
    else:
        flattened_tokens = tokens
    return Counter(flattened_tokens)


class Vocab:
    """Text vocabulary: manages the mapping between tokens and their indices."""

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        self.tokens = tokens if tokens is not None else []
        self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
        counter = self._count_corpus(self.tokens)
        # Sort tokens by frequency, highest first.
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        # Index 0 is reserved for the unknown token <unk>.
        self.idx_to_token = ['<unk>'] + self.reserved_tokens
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    @staticmethod
    def _count_corpus(tokens):
        if not tokens:
            return Counter()
        if isinstance(tokens[0], list):
            tokens = [token for sublist in tokens for token in sublist]
        return Counter(tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        # Map a single token, or a list/tuple of tokens, to indices.
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self[token] for token in tokens]

    def to_tokens(self, indices):
        # Map a single index, or a list/tuple of indices, back to tokens.
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self):
        return 0

    @property
    def token_freqs(self):
        return self._token_freqs


def load_corpus_time_machine(max_tokens=-1):
    """Return the character-level corpus (as indices) and the vocabulary for the Time Machine dataset."""
    lines = read_time_machine()
    tokens = tokenize(lines, 'char')
    vocab = Vocab(tokens)
    corpus = [vocab[token] for line in tokens for token in line]
    if max_tokens > 0:
        corpus = corpus[:max_tokens]
    return corpus, vocab


def seq_data_iter_random(corpus, batch_size, num_steps):
    """Generate minibatches of subsequences using random sampling."""
    # Start from a random offset so different epochs see different subsequences.
    offset = random.randint(0, num_steps - 1)
    corpus = corpus[offset:]
    num_subseqs = (len(corpus) - 1) // num_steps
    initial_indices = list(range(0, num_subseqs * num_steps, num_steps))
    random.shuffle(initial_indices)

    def data(pos):
        return corpus[pos:pos + num_steps]

    num_batches = num_subseqs // batch_size
    for i in range(0, batch_size * num_batches, batch_size):
        initial_indices_per_batch = initial_indices[i:i + batch_size]
        X = [data(j) for j in initial_indices_per_batch]
        # The target Y is the input X shifted one position to the right.
        Y = [data(j + 1) for j in initial_indices_per_batch]
        yield torch.tensor(X), torch.tensor(Y)


def seq_data_iter_sequential(corpus, batch_size, num_steps):
    """Generate minibatches of subsequences using sequential partitioning."""
    offset = random.randint(0, num_steps)
    num_tokens = ((len(corpus) - offset - 1) // batch_size) *