深度學習筆記40_中文文本分類-Pytorch實現

  • 🍨 本文為🔗365天深度學習訓練營?中的學習記錄博客
  • 🍖 原作者:K同學啊 | 接輔導、項目定制

一、我的環境

1.語言環境:Python 3.8

2.編譯器:Pycharm

3.深度學習環境:

  • torch==1.12.1+cu113
  • torchvision==0.13.1+cu113

、導入數據

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")             #忽略警告信息
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")import pandas as pd# 加載自定義中文數據
train_data = pd.read_csv('./data/train.csv', sep='\t', header=None)
print(train_data.head())

結果:

                       0              1
0      還有雙鴨山到淮陰的汽車票嗎13號的   Travel-Query
1                從這里怎么回家   Travel-Query
2       隨便播放一首專輯閣樓里的佛里的歌     Music-Play
3              給看一下墓王之王嘛  FilmTele-Play
4  我想看挑戰兩把s686打突變團競的游戲視頻     Video-Play

、構建詞典

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba# 中文分詞方法
tokenizer = jieba.lcutdef yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # 設置默認索引,如果找不到單詞,則會選擇默認索引print(vocab(['我','想','看','和平','精英','上','戰神','必備','技巧','的','游戲','視頻']))

結果:[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]

text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)print(text_pipeline('我想看和平精英上戰神必備技巧的游戲視頻'))
print(label_pipeline('Video-Play'))
結果:[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
4

生成數據批次和迭代器

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_text, _label) in batch:# 標簽列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)# 偏移量,即語句的總詞匯量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  # 返回維度dim中輸入元素的累計和return text_list.to(device), label_list.to(device), offsets.to(device)# 數據加載器,調用示例
dataloader = DataLoader(train_iter,batch_size=8,shuffle=False,collate_fn=collate_batch)

定義模型

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,  # 詞典大小embed_dim,  # 嵌入的維度sparse=False)  #self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)  # 初始化權重self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()  # 偏置值歸零def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

定義實例

num_class  = len(label_name)
vocab_size = len(vocab)
em_size    = 64
model      = TextClassificationModel(vocab_size, em_size, num_class).to(device)

定義訓練函數與評估函數

import timedef train(dataloader):model.train()  # 切換為訓練模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time = time.time()for idx, (text, label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()  # grad屬性歸零loss = criterion(predicted_label, label)  # 計算網絡輸出和真實值之間的差距,label為真實值loss.backward()  # 反向傳播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度裁剪optimizer.step()  # 每一步自動更新# 記錄acc與losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc / total_count, train_loss / total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切換為測試模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (text, label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 計算loss值# 記錄測試數據total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc / total_count, train_loss / total_count

訓練模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset# 超參數
EPOCHS = 10  # epoch
LR = 5  # 學習率
BATCH_SIZE = 64  # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 構建數據集
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset) * 0.8), int(len(train_dataset) * 0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 獲取當前的學習率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:1d} | time: {:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time() - epoch_start_time,val_acc, val_loss, lr))print('-' * 69)

?結果:

Batch [50/152], Loss: 0.0340, Accuracy: 0.4203
Batch [100/152], Loss: 0.0235, Accuracy: 0.5851
Batch [150/152], Loss: 0.0309, Accuracy: 0.6572
---------------------------------------------------------------------
| epoch 1 | time: 0.55s | valid_acc 0.814 valid_loss 0.012 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0104, Accuracy: 0.8165
Batch [100/152], Loss: 0.0099, Accuracy: 0.8215
Batch [150/152], Loss: 0.0092, Accuracy: 0.8329
---------------------------------------------------------------------
| epoch 2 | time: 0.44s | valid_acc 0.855 valid_loss 0.008 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0068, Accuracy: 0.8790
Batch [100/152], Loss: 0.0065, Accuracy: 0.8778
Batch [150/152], Loss: 0.0064, Accuracy: 0.8809
---------------------------------------------------------------------
| epoch 3 | time: 0.44s | valid_acc 0.874 valid_loss 0.007 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0050, Accuracy: 0.9105
Batch [100/152], Loss: 0.0051, Accuracy: 0.9101
Batch [150/152], Loss: 0.0048, Accuracy: 0.9130
---------------------------------------------------------------------
| epoch 4 | time: 0.44s | valid_acc 0.882 valid_loss 0.006 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0039, Accuracy: 0.9366
Batch [100/152], Loss: 0.0039, Accuracy: 0.9339
Batch [150/152], Loss: 0.0038, Accuracy: 0.9350
---------------------------------------------------------------------
| epoch 5 | time: 0.44s | valid_acc 0.896 valid_loss 0.006 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0028, Accuracy: 0.9519
Batch [100/152], Loss: 0.0030, Accuracy: 0.9517
Batch [150/152], Loss: 0.0030, Accuracy: 0.9494
---------------------------------------------------------------------
| epoch 6 | time: 0.44s | valid_acc 0.898 valid_loss 0.005 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0025, Accuracy: 0.9580
Batch [100/152], Loss: 0.0024, Accuracy: 0.9616
Batch [150/152], Loss: 0.0024, Accuracy: 0.9609
---------------------------------------------------------------------
| epoch 7 | time: 0.44s | valid_acc 0.902 valid_loss 0.005 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0018, Accuracy: 0.9764
Batch [100/152], Loss: 0.0019, Accuracy: 0.9739
Batch [150/152], Loss: 0.0019, Accuracy: 0.9724
---------------------------------------------------------------------
| epoch 8 | time: 0.44s | valid_acc 0.900 valid_loss 0.005 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0015, Accuracy: 0.9810
Batch [100/152], Loss: 0.0014, Accuracy: 0.9817
Batch [150/152], Loss: 0.0014, Accuracy: 0.9818
---------------------------------------------------------------------
| epoch 9 | time: 0.49s | valid_acc 0.906 valid_loss 0.005 | lr 0.500000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0013, Accuracy: 0.9831
Batch [100/152], Loss: 0.0013, Accuracy: 0.9831
Batch [150/152], Loss: 0.0014, Accuracy: 0.9825
---------------------------------------------------------------------
| epoch 10 | time: 0.54s | valid_acc 0.906 valid_loss 0.005 | lr 0.500000
---------------------------------------------------------------------

、預測

def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item()# ex_text_str = "隨便播放一首專輯閣樓里的佛里的歌"
ex_text_str = "還有雙鴨山到淮陰的汽車票嗎13號的"model = model.to("cpu")print("該文本的類別是:%s" %label_name[predict(ex_text_str, text_pipeline)])
該文本的類別是:Travel-Query

總結:?

  1. ?語料庫(原始文本)?:

    來源包括維基百科、網頁文本、新聞資訊及內部文本。
  2. ?文本清洗?:

    清洗原始文本,包括去除標點符號和特殊字符。該流程主要用于將原始文本數據轉化為可用于模型訓練的數值化向量,再通過深度學習模型進行文本分類。
    • ?分詞?:

      使用jieba分詞工具對清洗后的文本進行分詞處理。
    • ?建模?:

      采用不同的模型進行文本建模,包括循環神經網絡(RNN)、卷積神經網絡(CNN)、門控循環單元(GRU)和長短期記憶網絡(LSTM)。
    • ?文本向量化?:

      將分詞后的文本轉換為向量表示,方法包括TF-IDF和Word2vec。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/79197.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/79197.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/79197.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

010302-oss_反向代理_負載均衡-web擴展2-基礎入門-網絡安全

文章目錄 1 OSS1.1 什么是 OSS 存儲&#xff1f;1.2 OSS 核心功能1.3 OSS 的優勢1.4 典型使用場景1.5 如何接入 OSS&#xff1f;1.6 注意事項1.7 cloudreve實戰演示1.7.1 配置cloudreve連接阿里云oss1.7.2 常見錯誤1.7.3 安全測試影響 2 反向代理2.1 正向代理和反向代理2.2 演示…

【 Node.js】 Node.js安裝

下載 下載 | Node.js 中文網https://nodejs.cn/download/ 安裝 雙擊安裝包 點擊Next 勾選使用許可協議&#xff0c;點擊Next 選擇安裝位置 點擊Next 點擊Next 點擊Install 點擊Finish 完成安裝 添加環境變量 編輯【系統變量】下的變量【Path】添加Node.js的安裝路徑--如果…

Python基本語法(自定義函數)

自定義函數 Python語言沒有子程序&#xff0c;只有自定義函數&#xff0c;目的是方便我們重復使用相同的一 段程序。將常用的代碼塊定義為一個函數&#xff0c;以后想實現相同的操作時&#xff0c;只要調用函數名就可以了&#xff0c;而不需要重復輸入所有的語句。 函數的定義…

OpenGL-ES 學習(11) ---- EGL

目錄 EGL 介紹EGL 類型和初始化EGL初始化方法獲取 eglDisplay初始化 EGL選擇 Config構造 Surface構造 Context開始繪制 EGL Demo EGL 介紹 OpenGL-ES 是一個操作GPU的圖像API標準&#xff0c;它通過驅動向 GPU 發送相關圖形指令&#xff0c;控制圖形渲染管線狀態機的運行狀態&…

極簡5G專網解決方案

極簡5G專網解決方案 利用便攜式即插即用私有 5G 網絡提升您的智能創新。為您的企業提供無縫、安全且可擴展的 5G 解決方案。 提供極簡5G專網解決方案 Mantiswave Network Private Limited 提供全面的 5G 專用網絡解決方案&#xff0c;以滿足您企業的獨特需求。我們創新的“…

html:table表格

表格代碼示例&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title> </head> <body><!-- 標準表格。 --><table border"5"cellspacing&qu…

tkinter 電子時鐘 實現時間日期 可實現透明

以下是一個使用Tkinter模塊創建一個簡單的電子時鐘并顯示時間和日期的示例代碼&#xff1a; import tkinter as tk import time# 創建主窗口 root tk.Tk() root.overrideredirect(True) # 隱藏標題欄 root.attributes(-alpha, 0.7) # 設置透明度# 顯示時間的標簽 time_labe…

【報錯問題】 macOS 的安全策略(Gatekeeper)阻止了未簽名的原生模塊(bcrypt_lib.node)加載

這個錯誤是由于 macOS 的安全策略&#xff08;Gatekeeper&#xff09;阻止了未簽名的原生模塊&#xff08;bcrypt_lib.node&#xff09;加載 導致的。以下是具體解決方案&#xff1a; 1. 臨時允許加載未簽名模塊&#xff08;推薦先嘗試&#xff09; 在終端運行以下命令&#x…

AI實現制作logo的網站添加可選顏色模板

1.效果圖 LogoPalette.jsx import React, {useState} from react import HeadingDescription from ./HeadingDescription import Lookup from /app/_data/Lookup import Colors from /app/_data/Colors function LogoPalette({onHandleInputChange}) { const [selectOptio…

云原生后端架構的挑戰與應對策略

??個人主頁??:慌ZHANG-CSDN博客 ????期待您的關注 ???? 隨著云計算、容器化以及微服務等技術的快速發展,云原生架構已經成為現代軟件開發和運維的主流趨勢。企業通過構建云原生后端系統,能夠實現靈活的資源管理、快速的應用迭代和高效的系統擴展。然而,盡管云原…

【C++】模板為什么要extern?

模板為什么要extern&#xff1f; 在 C 中&#xff0c;多個編譯單元使用同一個模板時&#xff0c;是否可以不使用 extern 取決于模板的實例化方式&#xff08;隱式或顯式&#xff09;&#xff0c;以及你對編譯時間和二進制體積的容忍度。 1. 隱式實例化&#xff1a;可以不用 ex…

中小企業MES系統數據庫設計

版本&#xff1a;V1.0 日期&#xff1a;2025年5月2日 一、數據庫架構概覽 1.1 數據庫選型 數據類型數據庫類型技術選型用途時序數據&#xff08;傳感器讀數&#xff09;時序數據庫TimescaleDB存儲設備實時監控數據結構化業務數據關系型數據庫PostgreSQL工單、質量、設備等核心…

VUE篇之樹形特殊篇

根節點是level:1, level3及其子節點有關聯&#xff0c;但是和level2和他下面的子節點沒有關聯 思路&#xff1a;采用守護風琴效果&#xff0c;遍歷出level1和level2級節點&#xff0c;后面level3的節點&#xff0c;采用樹形結構進行關聯 <template><div :class"…

洛圣電玩系列部署實錄:一次自己從頭跑通的搭建過程

寫這篇文章不是為了“教大家怎么一步步安裝”&#xff0c;而是想把我自己完整跑通洛圣電玩整個平臺的經歷復盤下來。因為哪怕你找到了所謂的全套源碼資源&#xff0c;如果沒人告訴你這些資源之間是怎么連起來的&#xff0c;你依舊是一臉懵逼。 我拿到的是什么版本&#xff1f; …

騰訊云web服務器配置步驟是什么?web服務器有什么用途?

騰訊云web服務器配置步驟是什么?web服務器有什么用途&#xff1f; Web服務器配置步驟&#xff08;以常見環境為例&#xff09; 1. 安裝Web服務器軟件 Linux系統&#xff08;如Ubuntu&#xff09; Apache: sudo apt update sudo apt install apache2 Nginx: sudo apt install…

第37課 繪制原理圖——放置離頁連接符

什么是離頁連接符&#xff1f; 前邊我們介紹了網絡標簽&#xff08;Net Lable&#xff09;&#xff0c;可以讓兩根導線“隔空相連”&#xff0c;使原理圖更加清爽簡潔。 但是網絡標簽的使用也具有一定的局限性&#xff0c;對于兩張不同Sheet上的導線&#xff0c;網絡標簽就不…

Win下的Kafka安裝配置

一、準備工作&#xff08;可以不做&#xff0c;畢竟最新版kafka也不需要zk&#xff09; 1、Windows下安裝Zookeeper &#xff08;1&#xff09;官網下載Zookeeper 官網下載地址 &#xff08;2&#xff09;解壓Zookeeper安裝包到指定目錄C:\DevelopApp\zookeeper\apache-zoo…

前端Vue3 + 后端Spring Boot,前端取消請求后端處理邏輯分析

在 Vue3 Spring Boot 的技術棧下&#xff0c;前端取消請求后&#xff0c;后端是否繼續執行業務邏輯的答案仍然是 取決于請求處理的階段 和 Spring Boot 的實現方式。以下是結合具體技術的詳細分析&#xff1a; 1. 請求未到達 Spring Boot 場景&#xff1a;前端通過 AbortContr…

【藍橋杯省賽真題58】Scratch畫臺扇 藍橋杯scratch圖形化編程 中小學生藍橋杯省賽真題講解

目錄 scratch畫臺扇 一、題目要求 編程實現 二、案例分析 1、角色分析 2、背景分析 3、前期準備 三、解題思路 四、程序編寫 五、考點分析 六、推薦資料 1、scratch資料 2、python資料 3、C++資料 scratch畫臺扇 第十五屆青少年藍橋杯scratch編程省賽真題解析 …

GPT-4o 圖像生成與八個示例指南

什么是GPT-4o圖像生成&#xff1f; 簡單來說&#xff0c;GPT-4o圖像生成是集成在ChatGPT內部的一項功能。用戶可以直接在對話中&#xff0c;通過文本描述&#xff08;Prompt&#xff09;來創建、編輯和調整圖像。這與之前的圖像生成工具相比&#xff0c;體驗更流暢、交互性更強…