基于LSTM的文本分類1——模型搭建

源碼

# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Config(object):"""配置參數類,用于存儲模型和訓練的超參數"""def __init__(self, dataset, embedding):self.model_name = 'TextRNN'  # 模型名稱self.train_path = dataset + '/data/train.txt'  # 訓練集路徑self.dev_path = dataset + '/data/dev.txt'      # 驗證集路徑self.test_path = dataset + '/data/test.txt'    # 測試集路徑self.class_list = [x.strip() for x in open(dataset + '/data/class.txt').readlines()]  # 類別列表self.vocab_path = dataset + '/data/vocab.pkl'  # 詞表路徑self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # 模型保存路徑self.log_path = dataset + '/log/' + self.model_name  # 日志保存路徑# 加載預訓練詞向量(若提供)self.embedding_pretrained = torch.tensor(np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32')) \if embedding != 'random' else Noneself.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 訓練設備# 模型超參數self.dropout = 0.5              # 隨機失活率self.require_improvement = 1000 # 若超過該batch數效果未提升,則提前終止訓練self.num_classes = len(self.class_list)  # 類別數self.n_vocab = 0                # 詞表大小(運行時賦值)self.num_epochs = 10            # 訓練輪次self.batch_size = 128           # 批次大小self.pad_size = 32              # 句子填充/截斷長度self.learning_rate = 1e-3       # 學習率# 詞向量維度(使用預訓練時與預訓練維度對齊,否則設為300)self.embed = self.embedding_pretrained.size(1) \if self.embedding_pretrained is not None else 300self.hidden_size = 128          # LSTM隱藏層維度self.num_layers = 2             # LSTM層數'''基于LSTM的文本分類模型'''
class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()# 詞嵌入層:加載預訓練詞向量或隨機初始化if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)# 雙向LSTM層self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,bidirectional=True, batch_first=True, dropout=config.dropout)# 全連接分類層self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)  # 雙向LSTM輸出維度翻倍def forward(self, x):x, _ = x  # 輸入x為(padded_seq, seq_len),此處取padded_seqout = self.embedding(x)  # [batch_size, seq_len, embed_dim]out, _ = self.lstm(out)  # LSTM輸出維度 [batch_size, seq_len, hidden_size*2]# 取最后一個時間步的輸出作為句子表示out = self.fc(out[:, -1, :])  # [batch_size, num_classes]return out

數據集

上圖是我們這次做的文本分類。一共十個話題領域,我們的目標是輸入一句話,模型能夠實現對話題領域的區分。

上圖是我們使用的數據集。前面的漢字部分是模型學習的文本,后面接一個tab鍵是對該文本的分類。

配置類

配置的重點是模型的超參數,這里分析一下模型涉及的超參數。

Dropout隨機失活率

self.dropout = 0.5

在LSTM層之間隨機屏蔽部分神經元輸出,強迫模型學習冗余特征表示。公式:hdrop=h⊙mhdrop?=h⊙m,其中mm是伯努利分布的0-1掩碼。

早停閾值

elf.require_improvement = 1000

早停閾值的思想是:連續N個batch在驗證集無精度提升則終止訓練。首次訓練數據的時候可能摸不清楚情況,設置了較大的epoch值,浪費掉大量訓練時間。假設batch_size=128,數據集1萬樣本 , 每個epoch大約有78個batch。1000個batch的耐心期大約是13個epoch。

序列填充長度

self.pad_size = 32

序列填充長度的作用是,將變長文本序列處理為固定長度,滿足神經網絡批量處理的要求 。如果文本長度小于32,則填充特定的字符。如果文本長度大于32,則進行截斷,保留32個字符。

序列填充長度通常使用95分位方式獲得,獲取代碼如下

import numpy as np
lengths = [len(text.split()) for text in train_texts]
pad_size = int(np.percentile(lengths, 95))  # 覆蓋95%樣本

詞向量維度

self.embed = 300

詞向量的維度決定了語義空間的自由度 。假設我們使用字分割,每個文字對應一個300維的向量,將向量輸入到模型中完成訓練。

可以得出,向量維數越多,可以包含的信息數量就越多。但是并不是維度越高越好,下面的表是高維和低維的對比。

因子低維(d=50)高維(d=1024)
語義區分度相似詞易混淆可學習細粒度差異
計算復雜度O(Vd) 內存占用低GPU顯存需求高
訓練數據需求1M+ tokens即可需100M+ tokens
下游任務適配性適合簡單分類任務適合語義匹配任務

由于我們的數據量較小,所以使用較低的詞向量維度。另外,如果使用預訓練模型,詞向量維度的值需要和預訓練模型的值相同。

LSTM隱藏層維度

self.hidden_size = 128

隱藏層維度先賣個關子,下一章LSTM模型解析的時候講。

模型搭建

Input Text → Embedding Layer → Bidirectional LSTM → Last Timestep Output → FC Layer → Class Probabilities

文本是無法直接被計算機識別的,所以文本需要映射為稠密向量才能輸入給模型。因此在輸入模型前要加一步向量映射。

class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()# 詞嵌入層:加載預訓練詞向量或隨機初始化if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)# 雙向LSTM層self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,bidirectional=True, batch_first=True, dropout=config.dropout)# 全連接分類層self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)  # 雙向LSTM輸出維度翻倍def forward(self, x):x, _ = x  # 輸入x為(padded_seq, seq_len),此處取padded_seqout = self.embedding(x)  # [batch_size, seq_len, embed_dim]out, _ = self.lstm(out)  # LSTM輸出維度 [batch_size, seq_len, hidden_size*2]# 取最后一個時間步的輸出作為句子表示out = self.fc(out[:, -1, :])  # [batch_size, num_classes]return out

詞嵌入層

首先構建詞嵌入層,將本地的預訓練embedding加載到pytorch里面。

雙向LSTM層

我們使用雙向LSTM模型,即將文本從左到右訓練一次,也從右到左(倒著來)訓練一次。

參數名作用說明典型值
input_size輸入特征維度(等于嵌入維度)300
hidden_size隱藏層維度128/256
num_layersLSTM堆疊層數2-4
bidirectional啟用雙向LSTMTrue
batch_first輸入輸出使用(batch, seq, *)格式True
dropout層間dropout概率(僅當num_layers>1時生效)0.5

全連接分類層

self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

?全連接的輸入通道數是隱藏層維度的兩倍,原因是我們的模型是雙向的,雙向的結果都需要輸出給全連接層。

前向傳播

def forward(self, x):x, _ = x  # 解包(padded_seq, seq_len)out = self.embedding(x)  # [batch, seq_len, embed_dim]out, _ = self.lstm(out)  # [batch, seq_len, 2*hidden_size]out = self.fc(out[:, -1, :])  # 取最后時刻的輸出return out

首先提取輸入x的填充張量。可以看到張量里有4760這種值,這個值是我們在文字長度不夠時的填充內容。

?經過embedding映射后可以看到,張量out里的數據變成128*32*300的維度,300的維度就是詞向量維度,可以看到data里的數據都由原來的整數映射成了向量。

經過lstm運算后,out張量數據變成了128*32*128的維度

?最終經過全連接層,out張量變成了128*10維度的張量。128是batch_size,10個維度即代表該條數據在10個分類中的概率。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/74184.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/74184.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/74184.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

小了 60,500 倍,但更強;AI 的“深度詛咒”

作者:Ignacio de Gregorio 圖片來自 Unsplash 的 Bahnijit Barman 幾周前,我們看到 Anthropic 嘗試訓練 Claude 去通關寶可夢。模型是有點進展,但離真正通關還差得遠。 但現在,一個獨立的小團隊用一個只有一千萬參數的模型通關了…

nextjs使用02

并行路由 同一個頁面,放多個路由,, 目錄前面加,layout中可以當作插槽引入 import React from "react";function layout({children,notifications,user}:{children:React.ReactNode,notifications:React.ReactNode,user:React.Re…

github 無法在shell里鏈接

當我在shell端git push時,我發現總是22 timeout的問題。 我就進行了以下步驟的嘗試并最終得到了解決。 第一步,我先確定我可以curl github,也就是我網絡沒問題 curl -v https://github.com 如果這個時候不超時和報錯,說明網絡…

當前主流的大模型知識庫軟件對比分析

以下是當前主流的大模型知識庫軟件對比分析,涵蓋功能特性、適用場景及優劣勢,結合最新技術動態和行業實踐提供深度選型參考: 一、企業級智能知識庫平臺 1. 阿里云百煉(Model Studio) 核心能力:基于RAG技…

Java的比較器 Comparable 和 Comparator

在 Java 中,Comparable 和 Comparator 是用于對象排序的重要接口。它們提供了不同的排序方式,適用于不同的需求,同時在 Java 底層排序算法中發揮著關鍵作用。本文將從基礎概念、使用方法、排序實現(包括升序、降序)、底…

基于Qlearning強化學習的太赫茲信道信號檢測與識別matlab仿真

目錄 1.算法仿真效果 2.算法涉及理論知識概要 2.1 太赫茲信道特性 2.2 Q-learning強化學習基礎 2.3 基于Q-learning 的太赫茲信道信號檢測與識別系統 3.MATLAB核心程序 4.完整算法代碼文件獲得 1.算法仿真效果 matlab2024b仿真結果如下(完整代碼運行后無水印…

力扣刷題————199.二叉樹的右視圖

給定一個二叉樹的 根節點 root,想象自己站在它的右側,按照從頂部到底部的順序,返回從右側所能看到的節點值。 示例 1: 輸入:root [1,2,3,null,5,null,4] 輸出:[1,3,4] 解題思路:我們可以想到這…

文件包含漏洞的小點總結

文件本地與遠程包含: 文件包含有本地包含與遠程包含的區別:本地包含只能包含服務器已經有的問題; 遠程包含可以包含一切網絡上的文件。 本地包含: ①無限制 感受一下使用phpstudy的文件上傳,開啟phpstudy的apache…

深度學習處理時間序列(5)

Keras中的循環層 上面的NumPy簡單實現對應一個實際的Keras層—SimpleRNN層。不過,二者有一點小區別:SimpleRNN層能夠像其他Keras層一樣處理序列批量,而不是像NumPy示例中的那樣只能處理單個序列。也就是說,它接收形狀為(batch_si…

操作系統相關知識點

操作系統在進行線程切換時需要進行哪些動作? 保存當前線程的上下文 保存寄存器狀態、保存棧信息。 調度器選擇下一個線程 調度算法決策:根據策略(如輪轉、優先級、公平共享)從就緒隊列選擇目標線程。 處理優先級:實時…

從0到1:Rust 如何用 FFmpeg 和 OpenGL 打造硬核視頻特效

引言:視頻特效開發的痛點,你中了幾個? 視頻特效如今無處不在:短視頻平臺的濾鏡美化、直播間的實時美顏、影視后期的電影級調色,甚至 AI 生成內容的動態效果。無論是個人開發者還是團隊,視頻特效都成了吸引…

【并發編程 | 第一篇】線程相關基礎知識

1.并發和并行有什么區別 并發是指多核CPU上的多任務處理,多個任務在同一時刻真正同時執行。 并行是指單核CPU上的多任務處理,多個任務在同一時間段內交替執行,通過時間片輪轉實現交替執行,用于解決IO密集型瓶頸。 如何理解線程安…

Kafka 偏移量

在 Apache Kafka 中,偏移量(Offset)是一個非常重要的概念。它不僅用于標識消息的位置,還在多種場景中發揮關鍵作用。本文將詳細介紹 Kafka 偏移量的核心概念及其使用場景。 一、偏移量的核心概念 1. 定義 偏移量是一個非負整數…

18.redis基本操作

Redis(Remote Dictionary Server)是一個開源的、高性能的鍵值對(Key-Value)存儲數據庫,廣泛應用于緩存、消息隊列、實時分析等場景。它以其極高的讀寫速度、豐富的數據結構和靈活的應用方式而受到開發者的青睞。 Redis 的主要特點 ?高性能: ?內存存儲:Redis 將所有數…

歷年跨鏈合約惡意交易詳解(一)——THORChain退款邏輯漏洞

漏洞合約函數 function returnVaultAssets(address router, address payable asgard, Coin[] memory coins, string memory memo) public payable {if (router address(this)){for(uint i 0; i < coins.length; i){_adjustAllowances(asgard, coins[i].asset, coins[i].a…

通俗易懂的講解SpringBean生命周期

&#x1f4d5;我是廖志偉&#xff0c;一名Java開發工程師、《Java項目實戰——深入理解大型互聯網企業通用技術》&#xff08;基礎篇&#xff09;、&#xff08;進階篇&#xff09;、&#xff08;架構篇&#xff09;清華大學出版社簽約作家、Java領域優質創作者、CSDN博客專家、…

深入理解 `git pull --rebase` 與 `--allow-unrelated-histories`:區別、原理與實戰指南

&#x1f680; git pull --rebase vs --allow-unrelated-histories 全面解析 在日常使用 Git 時&#xff0c;我們經常遇到兩種拉取遠程代碼的方式&#xff1a;git pull --rebase 和 git pull --allow-unrelated-histories。它們的區別是什么&#xff1f;各自適用哪些場景&…

Matlab_Simulink中導入CSV數據與仿真實現方法

前言 在Simulink仿真中&#xff0c;常需將外部數據&#xff08;如CSV文件或MATLAB工作空間變量&#xff09;作為輸入信號驅動模型。本文介紹如何高效導入CSV數據至MATLAB工作空間&#xff0c;并通過From Workspace模塊實現數據到Simulink的精確傳輸&#xff0c;適用于運動控制…

Spring Boot 中 JdbcTemplate 處理枚舉類型轉換 和 減少數據庫連接的方法 的詳細說明,包含代碼示例和關鍵要點

以下是 Spring Boot 中 JdbcTemplate 處理枚舉類型轉換 和 減少數據庫連接的方法 的詳細說明&#xff0c;包含代碼示例和關鍵要點&#xff1a; 一、JdbcTemplate 處理枚舉類型轉換 1. 場景說明 假設數據庫存儲的是枚舉的 String 或 int 值&#xff0c;但 Java 實體類使用 enu…

API 安全之認證鑒權

作者&#xff1a;半天 前言 API 作為企業的重要數字資源&#xff0c;在給企業帶來巨大便利的同時也帶來了新的安全問題&#xff0c;一旦被攻擊可能導致數據泄漏重大安全問題&#xff0c;從而給企業的業務發展帶來極大的安全風險。正是在這樣的背景下&#xff0c;OpenAPI 規范…