第N5周:Pytorch文本分類入門

  • ????????🍨 本文為🔗365天深度學習訓練營中的學習記錄博客
  • ? ? ? ? 🍖 原作者:K同學啊

一、前期準備

1.加載數據
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")
#忽略警告信息#win10系統,調用GPU運行
#device = torch.device("cuda" if torch.cuda.is_available()else "cpu")
#devicedevice = torch.device("cpu")
device
device(type='cpu')
from torchtext.datasets import AG_NEWStrain_iter = list(AG_NEWS(split='train')) #加載 AG_News 數據集
num_class = len(set([label for (label, text) in train_iter]))
2.構建詞典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iteratortokenizer = get_tokenizer('basic_english') # 返回分詞器函數,訓練營內“get tokenizer函數詳解def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) #設置默認索引,如果找不到單詞,則會選擇默認索引
vocab(['here','is','an','example'])

?[475, 21, 30, 5297]

text_pipeline= lambda x:vocab(tokenizer(x))
label_pipeline=lambda x:int(x)-1text_pipeline('here is the an example')

?[475, 21, 2, 30, 5297]

label_pipeline('10')

9?

3.生成數據批次和迭代器
from torch.utils.data import DataLoaderdef collate_batch(batch):label_list,text_list,offsets =[],[],[0]for(_label, _text) in batch:#標簽列表label_list.append(label_pipeline(_label))#文本列表processed_text = torch.tensor(text_pipeline(_text),dtype=torch.int64)text_list.append(processed_text)#偏移量,即語句的總詞匯量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  #返回維度dim中輸入元素的累計和return label_list.to(device),text_list.to(device),offsets.to(device)#數據加載器
dataloader =DataLoader(train_iter,batch_size=8,shuffle =False,collate_fn=collate_batch)

二、準備模型

1.定義模型
from torch import nnclass TextclassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextclassificationModel,self).__init__()self.embedding =nn.EmbeddingBag(vocab_size, #詞典大小embed_dim,  #嵌入的維度sparse=False) #self.fc =nn.Linear(embed_dim,num_class)self.init_weights()def init_weights(self):initrange =0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self,text, offsets):embedded =self.embedding(text,offsets)return self.fc(embedded)
2.定義實例
num_class = len(set([label for(label,text)in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextclassificationModel(vocab_size,em_size,num_class).to(device)
3.定義訓練函數和評估函數
import timedef train(dataloader, model, optimizer, criterion, epoch):model.train()total_acc, train_loss, total_count = 0, 0, 0log_interval = 500start_time = time.time()for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()loss = criterion(predicted_label, label)loss.backward()optimizer.step()total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc / total_count, train_loss / total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader, model, criterion):model.eval()  # 切換為測試模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 計算loss值# 記錄測試數據total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

三、訓練模型

1.拆分數據集并運行模型
import timedef train(dataloader, model, optimizer, criterion, epoch):model.train()total_acc, train_loss, total_count = 0, 0, 0log_interval = 500start_time = time.time()for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()loss = criterion(predicted_label, label)loss.backward()optimizer.step()total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc / total_count, train_loss / total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader, model, criterion):model.eval()  # 切換為測試模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 計算loss值# 記錄測試數據total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count
| epoch 1 |  500/1782 batches | train_acc 0.904 train_loss 0.00450
| epoch 1 | 1000/1782 batches | train_acc 0.903 train_loss 0.00455
| epoch 1 | 1500/1782 batches | train_acc 0.904 train_loss 0.00443
---------------------------------------------------------------------
| epoch 1 | time:11.72s | valid_acc 0.901 valid_loss 0.005
---------------------------------------------------------------------
| epoch 2 |  500/1782 batches | train_acc 0.918 train_loss 0.00379
| epoch 2 | 1000/1782 batches | train_acc 0.920 train_loss 0.00377
| epoch 2 | 1500/1782 batches | train_acc 0.913 train_loss 0.00399
---------------------------------------------------------------------
| epoch 2 | time:11.52s | valid_acc 0.907 valid_loss 0.005
---------------------------------------------------------------------
| epoch 3 |  500/1782 batches | train_acc 0.930 train_loss 0.00323
| epoch 3 | 1000/1782 batches | train_acc 0.925 train_loss 0.00345
| epoch 3 | 1500/1782 batches | train_acc 0.925 train_loss 0.00350
---------------------------------------------------------------------
| epoch 3 | time:11.77s | valid_acc 0.915 valid_loss 0.004
---------------------------------------------------------------------
| epoch 4 |  500/1782 batches | train_acc 0.937 train_loss 0.00294
| epoch 4 | 1000/1782 batches | train_acc 0.931 train_loss 0.00317
| epoch 4 | 1500/1782 batches | train_acc 0.927 train_loss 0.00332
---------------------------------------------------------------------
| epoch 4 | time:11.81s | valid_acc 0.914 valid_loss 0.004
---------------------------------------------------------------------
| epoch 5 |  500/1782 batches | train_acc 0.951 train_loss 0.00243
| epoch 5 | 1000/1782 batches | train_acc 0.950 train_loss 0.00243
| epoch 5 | 1500/1782 batches | train_acc 0.949 train_loss 0.00245
---------------------------------------------------------------------
| epoch 5 | time:11.94s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
| epoch 6 |  500/1782 batches | train_acc 0.951 train_loss 0.00236
| epoch 6 | 1000/1782 batches | train_acc 0.951 train_loss 0.00241
| epoch 6 | 1500/1782 batches | train_acc 0.951 train_loss 0.00241
---------------------------------------------------------------------
| epoch 6 | time:11.69s | valid_acc 0.918 valid_loss 0.004
---------------------------------------------------------------------
| epoch 7 |  500/1782 batches | train_acc 0.952 train_loss 0.00233
| epoch 7 | 1000/1782 batches | train_acc 0.952 train_loss 0.00236
| epoch 7 | 1500/1782 batches | train_acc 0.952 train_loss 0.00235
---------------------------------------------------------------------
| epoch 7 | time:11.88s | valid_acc 0.920 valid_loss 0.004
---------------------------------------------------------------------
| epoch 8 |  500/1782 batches | train_acc 0.953 train_loss 0.00233
| epoch 8 | 1000/1782 batches | train_acc 0.954 train_loss 0.00226
| epoch 8 | 1500/1782 batches | train_acc 0.953 train_loss 0.00229
---------------------------------------------------------------------
| epoch 8 | time:11.92s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
| epoch 9 |  500/1782 batches | train_acc 0.956 train_loss 0.00223
| epoch 9 | 1000/1782 batches | train_acc 0.955 train_loss 0.00219
| epoch 9 | 1500/1782 batches | train_acc 0.955 train_loss 0.00223
---------------------------------------------------------------------
| epoch 9 | time:11.78s | valid_acc 0.919 valid_loss 0.004
---------------------------------------------------------------------
| epoch 10 |  500/1782 batches | train_acc 0.955 train_loss 0.00226
| epoch 10 | 1000/1782 batches | train_acc 0.954 train_loss 0.00223
| epoch 10 | 1500/1782 batches | train_acc 0.955 train_loss 0.00221
---------------------------------------------------------------------
| epoch 10 | time:11.82s | valid_acc 0.919 valid_loss 0.004
---------------------------------------------------------------------
2.使用測試數據集評估模型?
print('checking the results of test dataset.')
test_acc,test_loss = evaluate(test_dataloader,model, criterion)
print('test accuracy{:8.3f}'.format(test_acc))

四、學習心得

? ? ? ?本周額外安裝了 portalocker 庫,并且下載了AG_News數據集,并TextClassificationModel模型,首先對文本進行嵌入,然后對句子嵌入之后的結果進行均值聚合,從而最終實現了文本分類的任務。在訓練過程出現一些問題得到有效解決。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/86835.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/86835.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/86835.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

uniappx 安卓app項目本地打包運行,騰訊地圖報錯:‘鑒權失敗,請檢查你的key‘

根目錄下添加 AndroidManifest.xml 文件&#xff0c; <application><meta-data android:name"TencentMapSDK" android:value"騰訊地圖申請的key" /> </application> manifest.json 文件中添加&#xff1a; "app": {"…

【向上教育】結構化面試開口秘籍.pdf

向 上 教 育 XI A N G S H A N G E D U C A T I O N 結構化 面試 開口秘笈 目 錄 第一章 自我認知類 ........................................................................................................................... 2 第二章 工作關系處理類 .......…

Webpack 熱更新(HMR)原理詳解

&#x1f525; Webpack 熱更新&#xff08;HMR&#xff09;原理詳解 &#x1f4cc; 本文適用于 Vue、React 等使用 Webpack 的項目開發者&#xff0c;適配 Vue CLI / 自定義 Webpack 項目。 &#x1f3af; 一、什么是 HMR&#xff1f; Hot Module Replacement 是 Webpack 提供的…

MySQL索引完全指南

一、索引是什么&#xff1f;為什么這么重要&#xff1f; 索引就像字典的目錄 想象一下&#xff0c;你要在一本1000頁的字典里找"程序員"這個詞&#xff0c;你會怎么做&#xff1f; 沒有目錄&#xff1a;從第1頁開始一頁一頁翻&#xff0c;可能要翻500頁才能找到有…

學習使用dotnet-dump工具分析.net內存轉儲文件(2)

運行ShenNiusModularity項目&#xff0c;使用createdump工具dump完整的進程內存映射文件&#xff0c;然后運行dotnet-dump analyze命令加載dump文件。 ??可以先使用dumpheap命令顯示有關垃圾回收堆的信息和有關對象的收集統計信息。dumpheap支持多類參數&#xff08;如下所示…

Oracle BIEE 交互示例(一)同一分析內

Oracle BIEE 交互示例(一)同一分析內 1 示例背景2 實踐目標3 實操步驟3.1 創建數據集3.1.1 TEST_TABLE3.1.2 保存名字為【01 TEST_TABLE】3.2 創建分析3.2.1 創建列3.2.2 創建視圖3.2.2.1 數據透視表3.2.2.2 圖形3.2.2.3 表3.3 設置交互4 結果示例1 示例背景 版本:OBIEE 12…

使用API有效率地管理Dynadot域名,出售賬戶中的域名

關于Dynadot Dynadot是通過ICANN認證的域名注冊商&#xff0c;自2002年成立以來&#xff0c;服務于全球108個國家和地區的客戶&#xff0c;為數以萬計的客戶提供簡潔&#xff0c;優惠&#xff0c;安全的域名注冊以及管理服務。 Dynadot平臺操作教程索引&#xff08;包括域名郵…

Vite 打包原理詳解 + Webpack 對比

&#x1f680; Vite 打包原理詳解 Webpack 對比 &#x1f44b; 本文適合&#xff1a;Vite 使用者、Vue/React 工程師、希望搞清楚打包流程及與 Webpack 區別的開發者 &#x1f310; 技術背景&#xff1a;Vite 采用 ES Modules 原生瀏覽器能力驅動開發體驗&#xff0c;Webpack…

區塊鏈RWA(Real World Assets)系統開發全棧技術架構與落地實踐指南

一、技術架構設計&#xff1a;分層架構與模塊協同 1. 核心區塊鏈層 區塊鏈選型策略&#xff1a; 公鏈&#xff1a;以太坊主網&#xff08;安全性高&#xff0c;DeFi生態完備&#xff09; Polygon CDK&#xff08;Layer2定制化合規鏈&#xff0c;Gas費低至$0.003&#xff09;…

GBDT:梯度提升決策樹——集成學習中的預測利器

核心定位&#xff1a;一種通過串行集成弱學習器&#xff08;決策樹&#xff09;、以梯度下降方式逐步逼近目標函數的機器學習算法&#xff0c;在結構化數據預測任務中表現出色。 本文由「大千AI助手」原創發布&#xff0c;專注用真話講AI&#xff0c;回歸技術本質。拒絕神話或妖…

Redis持久化機制深度解析:RDB與AOF全面指南

摘要 本文深入剖析Redis的持久化機制&#xff0c;全面講解RDB和AOF兩種持久化方式的原理、配置與應用場景。通過詳細的操作步驟和原理分析&#xff0c;您將掌握如何配置Redis持久化策略&#xff0c;確保數據安全性與性能平衡。文章包含思維導圖概覽、命令實操演示、核心原理圖…

CentOS7升級openssh10.0p2和openssl3.5.0詳細操作步驟

背景 近期漏洞掃描時&#xff0c;發現有很多關于openssh的相關高危漏洞&#xff0c;因此需要升級openssh的版本 升級步驟 由于openssh和openssl的版本是需要相匹配的&#xff0c;這次計劃將openssh升級至10.0p2版本&#xff0c;將openssl升級至3.5.0版本&#xff0c;都是目前…

fishbot隨身系統安裝nvidia顯卡驅動

小魚的fishbot是已經配置好的ubuntu22.04,我聽說在預先配置系統時需要勾選安裝第三方圖形化軟件&#xff0c;不然直接安裝會有進不去圖形化界面的風險&#xff0c;若沒有勾選&#xff0c;建議使用其他安裝方法&#xff0c;比如禁用系統自帶的驅動那套安裝流程 1.打開設置->關…

學習昇騰開發的第十天--ffmpeg推拉流

1、FFmpeg推流 注意&#xff1a;在推流之前先運行rtsp-simple-server&#xff08;mediamtx&#xff09; ./mediamtx 1.1 UDP推流 ffmpeg -re -i input.mp4 -c copy -f rtsp rtsp://127.0.0.1:8554/stream 1.2 TCP推流 ffmpeg -re -i input.mp4 -c copy -rtsp_transport t…

成為一名月薪 2 萬的 web 安全工程師需要掌握哪些技能??

現在 web 安全工程師比較火&#xff0c;崗位比較稀缺&#xff0c;現在除了一些大公司對學歷要求嚴格&#xff0c;其余公司看中的大部分是能力。 有個親戚的兒子已經工作 2 年了……當初也是因為其他的行業要求比較高&#xff0c;所以才選擇的 web 安全方向。 資料免費分享給你…

Pytorch8實現CNN卷積神經網絡

CNN卷積神經網絡 本章提供一個對CNN卷積網絡的快速實現 全連接網絡 VS 卷積網絡 全連接神經網絡之所以不太適合圖像識別任務&#xff0c;主要有以下幾個方面的問題&#xff1a; 參數數量太多 考慮一個輸入10001000像素的圖片(一百萬像素&#xff0c;現在已經不能算大圖了)&…

平地起高樓: 環境搭建

技術選型 本小冊是采用純前端的技術棧模擬實現小程序架構的系列文章&#xff0c;所以主要以前端技術棧為主&#xff0c;但是為了模擬一個App應用的效果&#xff0c;以及小程序包下載管理流程的實現&#xff0c;我們還是需要搭建一個基礎的App應用。這里我們將選擇 Tauri2.0 來…

langgraph學習2 - MCP編程

3中通信方式&#xff1a; 目前sse用的很少 3.開發mcp框架 主流框架2個&#xff1a; MCP skd 官方 Fast Mcp V2 &#xff0c;&#xff08;V1捐給MCP 官方&#xff09; 大模型如何識別用哪個tools&#xff0c; 以及如何使用tools&#xff1a;

CSS 與 JavaScript 加載優化

&#x1f4c4; CSS 與 JavaScript 加載優化指南&#xff1a;位置、阻塞與性能 讓你的網頁飛起來&#xff01;&#x1f680; 本文詳細解析 CSS 和 JavaScript 標簽的放置位置如何影響頁面性能&#xff0c;涵蓋阻塞原理、瀏覽器機制和最佳實踐。掌握這些知識可顯著提升用戶體驗…

WSL安裝發行版上安裝podman

WSL安裝發行版上安裝podman 1.WSL拉取發行版1.1 拉取2.2.修改系統拉取的鏡像&#xff0c;可以加速軟件包的更新 2.podman安裝2.1.安裝podman 容器工具2.2.配置podman的鏡像倉庫2.3.拉取n8n鏡像并創建容器 本文在windows11上&#xff0c;使用WSL拉取并創建ubuntu24.04虛擬機&…