深度學習打卡第N6周:中文文本分類-Pytorch實現

  • 🍨?本文為🔗365天深度學習訓練營中的學習記錄博客
  • 🍖?原作者:K同學啊

一、準備工作

數據格式:

import torch
from torch import nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")device = torch.device("cuda" if torch.cuda.is_available else "cpu")import pandas as pd# CSV 格式通常為 無表頭(header=None),以制表符(sep='\t')分隔
train_data = pd.read_csv('./data/train.csv',sep='\t',header=None)
train_data.head()

# 構造數據集迭代器
def custom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,ytrain_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])

二、數據預處理

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba# 中文分詞方法
tokenizer = jieba.lcutdef yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])label_name = list(set(train_data[1].values[:]))text_pipeline = lambda x:vocab(tokenizer(x))
label_pipeline = lambda x:label_name.index(x)

三、模型搭建

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self,vocab_size,embed_dim,num_class):super(TextClassificationModel,self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,embed_dim)self.fc = nn.Linear(embed_dim,num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange,initrange)self.fc.weight.data.uniform_(-initrange,initrange)self.fc.bias.data.zero_()def forward(self,text,offsets):embedded = self.embedding(text,offsets)return self.fc(embedded)
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size,em_size,num_class).to(device)
model

import timedef train(dataloader):model.train()total_acc,train_loss,total_count = 0,0,0log_interval = 50start_time = time.time()for idx,(text,label,offsets) in enumerate(dataloader):predicted_label = model(text,offsets)optimizer.zero_grad()loss = criterion(predicted_label,label)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),0.1) # 梯度裁剪optimizer.step()total_acc += (predicted_label.argmax(1)==label).sum().item()train_loss += loss.item()*label.size(0)total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()total_acc,test_loss,total_count =0,0,0with torch.no_grad():for idx,(text,label,offsets) in enumerate(dataloader):predicted_label = model(text,offsets)loss = criterion(predicted_label,label)total_acc += (predicted_label.argmax(1)==label).sum().item()test_loss += loss.item()*label.size(0)total_count += label.size(0)return total_acc/total_count,test_loss/total_count

四、訓練模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset# 超參數
EPOCHS = 10
LR = 5
BATCH_SIZE = 64criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)
total_accu = Nonetrain_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)num_train = int(len(train_dataset)*0.8)
split_train,split_valid = random_split(train_dataset,[num_train,len(train_dataset)-num_train])train_dataloader = DataLoader(split_train,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_batch)for epoch in range(1,EPOCHS+1):epoch_start_time = time.time()train(train_dataloader)val_acc,val_loss = evaluate(valid_dataloader)lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:1d} | time: {:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time()-epoch_start_time,val_acc,val_loss,lr))print('-' * 69)
def predict(text):with torch.no_grad():text = torch.tensor(text_pipeline(text)).to(device)output = model(text,torch.tensor([0]).to(device))return output.argmax(1).item()
# ex_text_str = "還有南昌到哈爾濱西的火車票嗎?"
ex_text_str = "我想聽TWICE的新曲"
print("該文本的類別是:%s" %label_name[predict(ex_text_str)])

總結

本次學習對中文文本實現了分類,主要代碼和N1周基本一致。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/98941.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/98941.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/98941.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【代碼隨想錄day 24】 力扣 90. 集合II

視頻講解&#xff1a;https://www.bilibili.com/video/BV1vm4y1F71J/?vd_sourcea935eaede74a204ec74fd041b917810c 文檔講解&#xff1a;https://programmercarl.com/0090.%E5%AD%90%E9%9B%86II.html#%E6%80%9D%E8%B7%AF 力扣題目&#xff1a;https://leetcode.cn/problems/su…

.NET 6 文件下載

.NET 6 API中實現文件的下載。創建HttpHeaderConstant用于指定http頭。public sealed class HttpHeaderConstant{public const string RESPONSE_HEADER_CONTENTTYPE_STREAM "application/octet-stream";public const string RESPONSE_HEADER_NAME_FILENAME "f…

[數據結構——lesson6.棧]

目錄 引言 1.棧的概念和結構 棧的核心概念 棧的結構 2.棧的實現 2.1棧的實現方式 2.2棧的功能 2.3棧的聲明 1.順序棧 2。鏈式棧 2.4棧的功能實現 1.棧的初始化 2.判斷棧是否為空 3.返回棧頂元素 4.返回棧的大小 5.元素入棧 6.元素出棧 7.打印棧的元素 8.銷毀…

華為HICE云計算的含金量高嗎?

在數字時代的今天&#xff0c;云計算技術證飛速的發展成為企業數字化轉型的重要支撐。而華為作為領先的通信和信息技術公司&#xff0c;推出的HCIE云計算認證備受關注。接下來就來說說華為HCIE云計算認證的含金量到底有多高。HCIE認證被認為是華為認證中的最高等級&#xff0c;…

OSPF協議原理講解和實際配置(華為/思科)

OSPF&#xff08;open shorest path first&#xff0c;開放最短路徑優先&#xff09;是一種動態的&#xff0c;基于鏈路狀態的動態路由協議&#xff0c;廣泛的應用在企業網絡中&#xff0c;通過維護網絡拓撲信息&#xff0c;利用 Dijkstra 算法實現最短路徑&#xff0c;實現高效…

【開題答辯全過程】以 《黃帝內經》問答系統為例,包含答辯的問題和答案

個人簡介一名14年經驗的資深畢設內行人&#xff0c;語言擅長Java、php、微信小程序、Python、Golang、安卓Android等開發項目包括大數據、深度學習、網站、小程序、安卓、算法。平常會做一些項目定制化開發、代碼講解、答辯教學、文檔編寫、也懂一些降重方面的技巧。感謝大家的…

npm : 無法加載文件 C:\Program Files\nodejs\npm.ps1,因為在此系統上禁止運行腳

這個錯誤是由于 PowerShell 的執行策略限制&#xff0c;導致無法運行腳本。你可以通過以下步驟解決這個問題&#xff1a; 1. 查看當前的執行策略 打開 PowerShell&#xff0c;以管理員身份運行&#xff0c;輸入以下命令查看當前的執行策略&#xff1a; Get-ExecutionPolicy如果…

macOS蘋果電腦運行向日葵遠程控制軟件閃退

文章目錄問題原因分析修復附錄向日葵字太小按Ctrl鍵會彈出開始菜單的問題問題 向日葵是一款遠程控制的應用&#xff0c;在macOS下也能運行&#xff0c; 本來用的好好的&#xff0c;有一天升級后突然就運行不起來了&#xff0c;一點開能顯示幾秒首界面&#xff0c;立馬就自動退…

Linux dma-buf 框架原理、實現與應用詳解

1. 背景與意義 1.1 異構系統與緩沖區共享的挑戰 在現代 SoC、嵌入式、圖形和多媒體系統中&#xff0c;CPU、GPU、VPU、ISP、DMA 控制器等多個硬件單元需要高效地共享和傳遞大塊數據&#xff08;如圖像幀、視頻流、AI 張量等&#xff09;。如果每個設備都維護獨立的緩沖區&…

Scikit-learn Python機器學習 - 分類算法 - 樸素貝葉斯

鋒哥原創的Scikit-learn Python機器學習視頻教程&#xff1a; https://www.bilibili.com/video/BV11reUzEEPH 課程介紹 ? 本課程主要講解基于Scikit-learn的Python機器學習知識&#xff0c;包括機器學習概述&#xff0c;特征工程(數據集&#xff0c;特征抽取&#xff0c;特…

如何免費股票數據API(第13期):滬深A股《最新分時交易》數據獲取大全:附Python、Java等多語言實戰教程與接口文檔說明

在金融科技迅猛發展的今天&#xff0c;股票量化分析以其嚴謹的科學性和強大的系統性&#xff0c;正日益成為投資領域的主流方法論。任何卓越的量化模型的誕生&#xff0c;都離不開全面、精準、及時的數據支撐。無論是躍動著的實時交易數據、沉淀了歷史規律的K線走勢&#xff0c…

國標GB28181視頻EasyGBS視頻監控平臺:一網聯全城,交通道路可視化、視頻巡檢、應急指揮“三合一”。

一、方案背景?人車暴漲&#xff0c;路口告急&#xff1a;高峰堵、事故慢、取證難&#xff0c;老辦法已拖不動城市交通。破局之道&#xff0c;先看攝像頭——EasyGBS 嚴格遵循 GB28181 國標&#xff0c;一站式完成直播、存儲、檢索、轉碼&#xff0c;把萬千路口秒級搬上云端&am…

單元測試(白盒測試方法)

一、單元測試1.單元測試是對軟件的基本組成單元進行的測試&#xff0c;如函數、類或類的方法。單元測試是對軟件的最小可測試單元&#xff08;即可獨立編譯或匯編的程序模塊&#xff09;進行的測試活動&#xff0c;也稱為模塊測試二、白盒測試方法實例代碼public static int te…

2010-2022 同等學力申碩國考:軟件工程簡答題真題匯總

2010年簡答題 給出數據流圖的定義&#xff0c;并舉例說明數據流圖的四個基本構成成份。 數據流圖&#xff08;Data Flow Diagram, DFD&#xff09;是一種用于描述系統中數據流動和處理過程的圖形工具。它通過直觀的方式展示了系統的輸入數據如何經過一系列處理變換為輸出數據&a…

海外盲盒APP開發:如何用技術重構“驚喜經濟”

當盲盒的神秘感遇上技術的確定性&#xff0c;一場關于消費體驗的革命正在海外市場悄然發生。從概率算法的公平性到AR虛擬開箱的沉浸感&#xff0c;從跨境物流的實時追蹤到多語言支持的無縫切換&#xff0c;海外盲盒APP的開發是一場技術、設計與商業邏輯的深度融合。概率算法&am…

Aosp13 手機sim卡信號格顯示修改

工作中&#xff0c;客戶要求對信號格顯示偏弱不夠友好為由&#xff0c;提出修改&#xff0c;要求使其顯示信號強一些。在此記錄 一問題&#xff1a;修改系統sim卡顯示的信號格&#xff0c;在設備其他配置不變的情況下&#xff0c;使其信號格顯示比原有的要優秀二 …

硬件開發2-匯編2(ARMv7-A)- 裸機開發

一、指令1、b&#xff08;Branch&#xff09;原型&#xff1a;B<c> <label>作用&#xff1a;實現無條件跳轉&#xff0c;常用于不返回的跳轉場景特點&#xff1a;僅跳轉到目標地址&#xff0c;不保存返回地址示例&#xff1a;b reset ;跳轉到reset標號處執…

清源 SCA 社區版更新(V4.2.0)|漏洞前置感知、精準修復、合規清晰,筑牢軟件供應鏈安全防線!

隨著數字化進程加速&#xff0c;軟件供應鏈安全威脅日益復雜&#xff0c;公開漏洞響應滯后、0day 攻擊防不勝防、組件升級編譯失敗、安全與合規風險混雜......這些痛點讓企業安全團隊、運維人員及研發團隊疲于應對。自 2025 年 7 月 1 日安勢清源 SCA 社區版首次正式發布以及在…

氚燃料增殖里程碑:MIT新型BABY包層技術實驗驗證

● 導語 5月20日&#xff0c;麻省理工學院&#xff08;MIT&#xff09;發文稱&#xff0c;BABY實驗首次獲取了氚在裝置內增殖的實測數據&#xff0c;驗證了核心模型&#xff0c;并為未來核聚變電廠的燃料自循環奠定了重要基礎。 原文&#x1f447;&#x1f3fb; https://m…

python+springboot+uniapp微信小程序題庫系統 在線答題 題目分類 錯題本管理 學習記錄查詢系統

目錄技術棧介紹具體實現截圖系統設計研究方法&#xff1a;設計步驟設計流程核心代碼部分展示研究方法詳細視頻演示試驗方案論文大綱源碼獲取/詳細視頻演示技術棧介紹 Django-SpringBoot-php-Node.js-flask 本課題的研究方法和研究步驟基本合理&#xff0c;難度適中&#xff0…