從代碼學習深度學習 - LSTM PyTorch版

文章目錄

  • 前言
  • 一、數據加載與預處理
    • 1.1 代碼實現
    • 1.2 功能解析
  • 二、LSTM介紹
    • 2.1 LSTM原理
    • 2.2 模型定義
      • 代碼解析
  • 三、訓練與預測
    • 3.1 訓練邏輯
      • 代碼解析
    • 3.2 可視化工具
      • 功能解析
      • 功能結果
  • 總結


前言

深度學習中的循環神經網絡(RNN)及其變種長短期記憶網絡(LSTM)在處理序列數據(如文本、時間序列等)方面表現出色。本篇博客將通過一個完整的PyTorch實現,帶你從零開始學習如何使用LSTM進行文本生成任務。我們將基于H.G. Wells的《時間機器》數據集,逐步展示數據預處理、模型定義、訓練與預測的全過程。通過代碼和文字的結合,幫助你深入理解LSTM的實現細節及其在自然語言處理中的應用。

本文的代碼分為四個主要部分:

  1. 數據加載與預處理(utils_for_data.py
  2. LSTM模型定義(Jupyter Notebook中的模型部分)
  3. 訓練與預測邏輯(utils_for_train.py
  4. 可視化工具(utils_for_huitu.py

以下是詳細的實現與解析。


一、數據加載與預處理

首先,我們需要加載《時間機器》數據集并進行預處理。以下是utils_for_data.py中的完整代碼及其功能說明。

1.1 代碼實現

import random
import re
import torch
from collections import Counterdef read_time_machine():"""將時間機器數據集加載到文本行的列表中"""with open('timemachine.txt', 'r') as f:lines = f.readlines()return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]def tokenize(lines, token='word'):"""將文本行拆分為單詞或字符詞元"""if token == 'word':return [line.split() for line in lines]elif token == 'char':return [list(line) for line in lines]else:print(f'錯誤:未知詞元類型:{token}')def count_corpus(tokens):"""統計詞元的頻率"""if not tokens:return Counter()if isinstance(tokens[0], list):flattened_tokens = [token for sublist in tokens for token in sublist]else:flattened_tokens = tokensreturn Counter(flattened_tokens)class Vocab:"""文本詞表類,用于管理詞元及其索引的映射關系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []counter = self._count_corpus(self.tokens)self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1@staticmethoddef _count_corpus(tokens):if not tokens:return Counter()if isinstance(tokens[0], list):tokens = [token for sublist in tokens for token in sublist]return Counter(tokens)def __len__(self):return len(self.idx_to_token)def __getitem__(self, tokens):if not isinstance(tokens, (list, tuple)):return self.token_to_idx.get(tokens, self.unk)return [self[token] for token in tokens]def to_tokens(self, indices):if not isinstance(indices, (list, tuple)):return self.idx_to_token[indices]return [self.idx_to_token[index] for index in indices]@propertydef unk(self):return 0@propertydef token_freqs(self):return self._token_freqsdef load_corpus_time_machine(max_tokens=-1):lines = read_time_machine()tokens = tokenize(lines, 'char')vocab = Vocab(tokens)corpus = [vocab[token] for line in tokens for token in line]if max_tokens > 0:corpus = corpus[:max_tokens]return corpus, vocabdef seq_data_iter_random(corpus, batch_size, num_steps):offset = random.randint(0, num_steps - 1)corpus = corpus[offset:]num_subseqs = (len(corpus) - 1) // num_stepsinitial_indices = list(range(0, num_subseqs * num_steps, num_steps))random.shuffle(initial_indices)def data(pos):return corpus[pos:pos + num_steps]num_batches = num_subseqs // batch_sizefor i in range(0, batch_size * num_batches, batch_size):initial_indices_per_batch = initial_indices[i:i + batch_size]X = [data(j) for j in initial_indices_per_batch]Y = [data(j + 1) for j in initial_indices_per_batch]yield torch.tensor(X), torch.tensor(Y)def seq_data_iter_sequential(corpus, batch_size, num_steps):offset = random.randint(0, num_steps)num_tokens = ((len(corpus) - offset - 1) // batch_size) *

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/74708.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/74708.shtml
英文地址,請注明出處:http://en.pswp.cn/web/74708.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

easy-poi 一對多導出

1. 需求&#xff1a; 某一列上下兩行單元格A,B值一樣且這兩個單元格&#xff0c; 前面所有列對應單元格值一樣的話&#xff0c; 就對A,B 兩個單元格進行縱向合并單元格 1. 核心思路&#xff1a; 先對數據集的國家&#xff0c;省份&#xff0c;城市...... id 身份證進行排序…

AI比人腦更強,因為被植入思維模型【42】思維投影思維模型

giszz的理解&#xff1a;本質和外在。我們的行為舉止&#xff0c;都是我們的內心的表現。從外邊可以看內心&#xff0c;從內心可以判斷外在。曾國藩有&#xff17;個識人的方法&#xff0c;大部分的人在他的面前如同沒穿衣服一樣。對于我們自身的啟迪&#xff0c;我認為有四點&…

Spring Boot 打印日志

1.通過slf4j包中的logger對象打印日志 Spring Boot內置了日志框架slf4j&#xff0c;在程序中調用slf4j來輸出日志 通過創建logger對象打印日志&#xff0c;Logger 對象是屬于 org.slf4j 包下的不要導錯包。 2.日志級別 日志級別從高到低依次為: FATAL:致命信息&#xff0c;表…

【IOS webview】源代碼映射錯誤,頁面卡住不動

報錯場景 safari頁面報源代碼映射錯誤&#xff0c;頁面卡住不動。 機型&#xff1a;IOS13 技術棧&#xff1a;react 其他IOS也會報錯&#xff0c;但不影響頁面顯示。 debug webpack配置不要GENERATE_SOURCEMAP。 解決方法&#xff1a; GENERATE_SOURCEMAPfalse react-app…

ES中經緯度查詢geo_point

0. ES版本 6.x版本 1. 創建索引 PUT /location {"settings": {"number_of_shards": 1,"number_of_replicas": 0},"mappings": {"location": {"properties": {"id": {"type": "keywor…

OpenCV界面編程

《OpenCV計算機視覺開發實踐&#xff1a;基于Python&#xff08;人工智能技術叢書&#xff09;》(朱文偉&#xff0c;李建英)【摘要 書評 試讀】- 京東圖書 OpenCV的Python開發環境搭建(Windows)-CSDN博客 OpenCV也支持有限的界面編程&#xff0c;主要是針對窗口、控件和鼠標…

GOC L2 第五課模運算和周期二

課堂回顧&#xff1a; 求取余數的過程叫做模運算 每輪的動作都是重復的&#xff0c;我們稱這個過程位周期。 課堂學習&#xff1a; 剩余計算器 秋天到了&#xff0c;學校里的蘋果熟了&#xff0c;太乙老師&#xff0c;想讓哪吒幫忙設計一個計算器&#xff0c;看每個小朋友能分…

54.大學生心理健康管理系統(基于springboot項目)

目錄 1.系統的受眾說明 2.相關技術 2.1 B/S結構 2.2 MySQL數據庫 3.系統分析 3.1可行性分析 3.1.1時間可行性 3.1.2 經濟可行性 3.1.3 操作可行性 3.1.4 技術可行性 3.1.5 法律可行性 3.2系統流程分析 3.3系統功能需求分析 3.4 系統非功能需求分析 4.系統設計…

Redis 除了數據類型外的核心功能 的詳細說明,包含事務、流水線、發布/訂閱、Lua 腳本的完整代碼示例和表格總結

以下是 Redis 除了數據類型外的核心功能 的詳細說明&#xff0c;包含事務、流水線、發布/訂閱、Lua 腳本的完整代碼示例和表格總結&#xff1a; 1. Redis 事務&#xff08;Transactions&#xff09; 功能描述 事務通過 MULTI 和 EXEC 命令將一組命令打包執行&#xff0c;保證…

STM32F103C8T6單片機硬核原理篇:討論GPIO的基本原理篇章1——只討論我們的GPIO簡單輸入和輸出

目錄 前言 輸出時的GPIO控制部分 標準庫是如何操作寄存器完成GPIO驅動的初始化的&#xff1f; 問題1&#xff1a;如何掌握GPIO的編程細節——跟寄存器如何打交道 問題2&#xff1a;哪些寄存器&#xff0c;去哪里找呢&#xff1f; 問題三&#xff0c;寄存器的含義&#xff…

前端布局難題:父元素padding導致子元素無法全屏?3種解決方案

大家好&#xff0c;我是一諾。今天要跟大家分享一個我在實際項目中經常用到的CSS技巧——如何讓子元素突破父元素的padding限制&#xff0c;實現真正的全屏寬度效果。 為什么會有這個需求&#xff1f; 記得我剛入行的時候&#xff0c;接到一個需求&#xff1a;要在內容區插入…

當網頁受到DDOS網絡攻擊有哪些應對方法?

分布式拒絕服務攻擊也是人們較為熟悉的DDOS攻擊&#xff0c;這類攻擊會通過大量受控制的僵尸網絡向目標服務器發送請求&#xff0c;以此來消耗服務器中的資源&#xff0c;致使用戶無法正常訪問&#xff0c;當網頁受到分布式拒絕服務攻擊時都有哪些應對方法呢&#xff1f; 建立全…

LeNet-5簡介及matlab實現

文章目錄 一、LeNet-5網絡結構簡介二、LeNet-5每一層的實現原理2.1. 第一層 (C1) &#xff1a;卷積層&#xff08;Convolution Layer&#xff09;2.2. 第二層 (S2) &#xff1a;池化層&#xff08;Pooling Layer&#xff09;2.3. 第三層&#xff08;C3&#xff09;&#xff1a;…

【LLM】MCP(Python):實現 stdio 通信的Client與Server

本文將詳細介紹如何使用 Model Context Protocol (MCP) 在 Python 中實現基于 STDIO 通信的 Client 與 Server。MCP 是一個開放協議&#xff0c;它使 LLM 應用與外部數據源和工具之間的無縫集成成為可能。無論你是構建 AI 驅動的 IDE、改善 chat 交互&#xff0c;還是構建自定義…

Docker 安裝 Elasticsearch 教程

目錄 一、安裝 Elasticsearch 二、安裝 Kibana 三、安裝 IK 分詞器 四、Elasticsearch 常用配置 五、Elasticsearch 常用命令 一、安裝 Elasticsearch &#xff08;一&#xff09;創建 Docker 網絡 因為后續還需要部署 Kibana 容器&#xff0c;所以需要讓 Elasticsearch…

Swagger @ApiOperation

ApiOperation 注解并非 Spring Boot 自帶的注解&#xff0c;而是來自 Swagger 框架&#xff0c;Swagger 是一個規范且完整的框架&#xff0c;用于生成、描述、調用和可視化 RESTful 風格的 Web 服務&#xff0c;而 ApiOperation 主要用于為 API 接口的操作添加描述信息。以下為…

【奇點時刻】GPT4o新圖像生成模型底層原理深度洞察報告(篇2)

由于上一篇解析深度不足&#xff0c;經過查看學習相關論文&#xff0c;以下是一份對 GPT-4o 最新的圖像生成模型 的深度梳理與洞察&#xff0c;從模型原理到社區解讀、對比傳統擴散模型&#xff0c;再到對未來趨勢的分析。為了便于閱讀&#xff0c;整理成以下七個部分&#xff…

C# 窗體應用(.FET Framework ) 打開文件操作

一、 打開文件或文件夾加載數據 1. 定義一個列表用來接收路徑 public List<string> paths new List<string>();2. 打開文件選擇一個文件并將文件放入列表中 OpenFileDialog open new OpenFileDialog(); // 過濾 open.Filter "(*.jpg;*.jpge;*.bmp;*.png…

Scala 面向對象編程總結

???抽象屬性和抽象方法 基本語法 定義抽象類&#xff1a;abstract class Person{} //通過 abstract 關鍵字標記抽象類定義抽象屬性&#xff1a;val|var name:String //一個屬性沒有初始化&#xff0c;就是抽象屬性定義抽象方法&#xff1a;def hello():String //只聲明而沒…

人工智能賦能工業制造:智能制造的未來之路

一、引言 隨著人工智能技術的飛速發展&#xff0c;其應用場景不斷拓展&#xff0c;從消費電子到醫療健康&#xff0c;從金融科技到交通運輸&#xff0c;幾乎涵蓋了所有行業。而工業制造作為國民經濟的支柱產業&#xff0c;也在人工智能的浪潮中迎來了深刻的變革。智能制造&…