深度學習筆記39_Pytorch文本分類入門

  • 🍨 本文為🔗365天深度學習訓練營?中的學習記錄博客
  • 🍖 原作者:K同學啊 | 接輔導、項目定制

一、我的環境

1.語言環境:Python 3.8

2.編譯器:Pycharm

3.深度學習環境:

  • torch==1.12.1+cu113
  • torchvision==0.13.1+cu113

、導入數據

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")             #忽略警告信息
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(split='train')      # 加載 AG News 數據集

、構建詞典

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iteratortokenizer  = get_tokenizer('basic_english') # 返回分詞器函數def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # 設置默認索引,如果找不到單詞,則會選擇默認索引
print(vocab(['here', 'is', 'an', 'example']))

結果:?[475, 21, 30, 5297]

text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
print(text_pipeline('here is the an example'))
結果:[475, 21, 2, 30, 5297]
print(label_pipeline('10'))
結果:10

生成數據批次和迭代器

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_label, _text) in batch:# 標簽列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)# 偏移量,即語句的總詞匯量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list  = torch.cat(text_list)offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0) #返回維度dim中輸入元素的累計和return label_list.to(device), text_list.to(device), offsets.to(device)# 數據加載器
dataloader = DataLoader(train_iter,batch_size=8,shuffle   =False,collate_fn=collate_batch)

定義模型

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,   # 詞典大小embed_dim,    # 嵌入的維度sparse=False) # self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

定義實例

num_class  = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size     = 64
model      = TextClassificationModel(vocab_size, em_size, num_class).to(device)

定義訓練函數與評估函數

import timedef train(dataloader):model.train()  # 切換為訓練模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 500start_time   = time.time()for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()                    # grad屬性歸零loss = criterion(predicted_label, label) # 計算網絡輸出和真實值之間的差距,label為真實值loss.backward()                          # 反向傳播optimizer.step()  # 每一步自動更新# 記錄acc與losstotal_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切換為測試模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 計算loss值# 記錄測試數據total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

?結果:

| epoch 1 | 500/1782 batches| train_acc 0.901 train_loss 0.00458
| epoch 1 | 1000/1782 batches| train_acc 0.905 train_loss 0.00438
| epoch 1 | 1500/1782 batches| train_acc 0.908 train_loss 0.00437
---------------------------------------------------------------------
| epoch 1 | time:6.30s |valid_acc 0.907 | valid_loss 0.004
---------------------------------------------------------------------
| epoch 2 | 500/1782 batches| train_acc 0.917 train_loss 0.00381
| epoch 2 | 1000/1782 batches| train_acc 0.917 train_loss 0.00383
| epoch 2 | 1500/1782 batches| train_acc 0.917 train_loss 0.00386
---------------------------------------------------------------------
| epoch 2 | time:6.26s |valid_acc 0.911 | valid_loss 0.004
---------------------------------------------------------------------
| epoch 3 | 500/1782 batches| train_acc 0.929 train_loss 0.00330
| epoch 3 | 1000/1782 batches| train_acc 0.927 train_loss 0.00340
| epoch 3 | 1500/1782 batches| train_acc 0.923 train_loss 0.00354
---------------------------------------------------------------------
| epoch 3 | time:6.21s |valid_acc 0.935 | valid_loss 0.003
---------------------------------------------------------------------
| epoch 4 | 500/1782 batches| train_acc 0.933 train_loss 0.00306
| epoch 4 | 1000/1782 batches| train_acc 0.932 train_loss 0.00311
| epoch 4 | 1500/1782 batches| train_acc 0.929 train_loss 0.00318
---------------------------------------------------------------------
| epoch 4 | time:6.22s |valid_acc 0.916 | valid_loss 0.003
---------------------------------------------------------------------
| epoch 5 | 500/1782 batches| train_acc 0.948 train_loss 0.00253
| epoch 5 | 1000/1782 batches| train_acc 0.949 train_loss 0.00242
| epoch 5 | 1500/1782 batches| train_acc 0.951 train_loss 0.00238
---------------------------------------------------------------------
| epoch 5 | time:6.23s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------
| epoch 6 | 500/1782 batches| train_acc 0.951 train_loss 0.00241
| epoch 6 | 1000/1782 batches| train_acc 0.952 train_loss 0.00236
| epoch 6 | 1500/1782 batches| train_acc 0.952 train_loss 0.00235
---------------------------------------------------------------------
| epoch 6 | time:6.26s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------
| epoch 7 | 500/1782 batches| train_acc 0.954 train_loss 0.00228
| epoch 7 | 1000/1782 batches| train_acc 0.951 train_loss 0.00238
| epoch 7 | 1500/1782 batches| train_acc 0.954 train_loss 0.00228
---------------------------------------------------------------------
| epoch 7 | time:6.26s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------
| epoch 8 | 500/1782 batches| train_acc 0.953 train_loss 0.00227
| epoch 8 | 1000/1782 batches| train_acc 0.955 train_loss 0.00224
| epoch 8 | 1500/1782 batches| train_acc 0.954 train_loss 0.00224
---------------------------------------------------------------------
| epoch 8 | time:6.32s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------
| epoch 9 | 500/1782 batches| train_acc 0.955 train_loss 0.00218
| epoch 9 | 1000/1782 batches| train_acc 0.953 train_loss 0.00227
| epoch 9 | 1500/1782 batches| train_acc 0.955 train_loss 0.00227
---------------------------------------------------------------------
| epoch 9 | time:6.24s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------
| epoch 10 | 500/1782 batches| train_acc 0.952 train_loss 0.00229
| epoch 10 | 1000/1782 batches| train_acc 0.955 train_loss 0.00220
| epoch 10 | 1500/1782 batches| train_acc 0.956 train_loss 0.00220
---------------------------------------------------------------------
| epoch 10 | time:6.29s |valid_acc 0.954 | valid_loss 0.002
---------------------------------------------------------------------

定義訓練函數與評估函數

print('Checking the results of test dataset.')
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))
Checking the results of test dataset.
test accuracy    0.905

總結:?

  1. 預訓練詞向量:使用GloVe、FastText等預訓練詞向量能顯著提升性能

  2. 正則化:合理使用dropout、權重衰減等技術防止過擬合

  3. 超參數調優:學習率、批大小、隱藏層維度等對模型性能影響很大

  4. 遷移學習:對于小數據集,考慮使用BERT等預訓練模型進行微調

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/76217.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/76217.shtml
英文地址,請注明出處:http://en.pswp.cn/web/76217.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

二分查找-LeetCode

題目 給定一個 n 個元素有序的&#xff08;升序&#xff09;整型數組 nums 和一個目標值 target&#xff0c;寫一個函數搜索 nums 中的 target&#xff0c;如果目標值存在返回下標&#xff0c;否則返回 -1。 示例 1: 輸入: nums [-1,0,3,5,9,12], target 9 輸出: 4 解釋: …

從 Ext 到 F2FS,Linux 文件系統與存儲技術全面解析

與 Windows 和 macOS 操作系統不同&#xff0c;Linux 是由愛好者社區開發的大型開源項目。它的代碼始終可供那些想要做出貢獻的人使用&#xff0c;任何人都可以根據個人需求自由調整它&#xff0c;或在其基礎上創建自己的發行版本。這就是為什么 Linux 存在如此多的變體&#x…

leetcode:3210. 找出加密后的字符串(python3解法)

難度&#xff1a;簡單 給你一個字符串 s 和一個整數 k。請你使用以下算法加密字符串&#xff1a; 對于字符串 s 中的每個字符 c&#xff0c;用字符串中 c 后面的第 k 個字符替換 c&#xff08;以循環方式&#xff09;。 返回加密后的字符串。 示例 1&#xff1a; 輸入&#xff…

JVM詳解(曼波腦圖版)

(?ω?)&#xff89; 好噠&#xff01;曼波會用最可愛的比喻給小白同學講解JVM&#xff0c;準備好開啟奇妙旅程了嗎&#xff1f;(??????)? &#x1f4cc; 思維導圖 ━━━━━━━━━━━━━━━━━━━ &#x1f34e; JVM是什么&#xff1f;&#xff08;蘋果式比…

ZStack文檔DevOps平臺建設實踐

&#xff08;一&#xff09;前言 對于軟件產品而言&#xff0c;文檔是不可或缺的一環。文檔能幫助用戶快速了解并使用軟件&#xff0c;包括不限于特性概覽、用戶手冊、API手冊、安裝部署以及場景實踐教程等。由于軟件與文檔緊密耦合&#xff0c;面對業務的瞬息萬變以及軟件的飛…

Git創建分支操作指南

1. 創建新分支但不切換&#xff08;僅創建&#xff09; git branch <分支名>示例&#xff1a;創建一個名為 new-feature 的分支git branch new-feature2. 創建分支并立即切換到該分支 git checkout -b <分支名> # 傳統方式 # 或 git switch -c <分支名&g…

package.json 中的那些版本數字前面的符號是什么意思?

1. 語義化版本&#xff08;SemVer&#xff09; 語義化版本的格式是 MAJOR.MINOR.PATCH&#xff0c;其中&#xff1a; MAJOR&#xff1a;主版本號&#xff0c;表示不兼容的 API 修改。MINOR&#xff1a;次版本號&#xff0c;表示新增功能但保持向后兼容。PATCH&#xff1a;修訂號…

如何有效防止服務器被攻擊

首先&#xff0c;我們要明白服務器被攻擊的危害有多大。據不完全統計&#xff0c;每年因服務器遭受攻擊而導致的經濟損失高達數十億。這可不是一個小數目&#xff0c;就好比您辛苦積攢的財富&#xff0c;瞬間被人偷走了一大半。 要有效防止服務器被攻擊&#xff0c;第一步就是…

Chainlit 快速構建Python LLM應用程序

背景 chainlit 是一款簡單易用的Web UI goggle&#xff0c;它支持使用 Python 語言快速構建 LLM 應用程序&#xff0c;提供了豐富的功能&#xff0c;包括文本分析&#xff0c;情感分析等。 這里我們以官網openai提供的例子&#xff0c;快速的開發一個帶有UI的聊天界面&#xf…

華為OD機試真題——硬件產品銷售方案(2025A卷:100分)Java/python/JavaScript/C++/C語言/GO六種最佳實現

2025 A卷 100分 題型 本文涵蓋詳細的問題分析、解題思路、代碼實現、代碼詳解、測試用例以及綜合分析&#xff1b; 并提供Java、python、JavaScript、C、C語言、GO六種語言的最佳實現方式&#xff01; 2025華為OD真題目錄全流程解析/備考攻略/經驗分享 華為OD機試真題《硬件產品…

【數據結構_6】雙向鏈表的實現

一、實現MyDLinkedList&#xff08;雙向鏈表&#xff09; package LinkedList;public class MyDLinkedList {//首先我們要創建節點&#xff08;因為雙向鏈表和單向鏈表的節點不一樣&#xff01;&#xff01;&#xff09;static class Node{public String val;public Node prev…

做Data+AI的長期主義者,加速全球化戰略布局

在Data與AI深度融合的新紀元&#xff0c;唯有秉持長期主義方能真正釋放數智化的深層價值。2025年是人工智能從技術爆發轉向規模化落地的關鍵節點&#xff0c;也是標志著袋鼠云即將迎來十周年的重要里程碑。2025年4月16日&#xff0c;袋鼠云成功舉辦了“做DataAI的長期主義者——…

構建基于PHP和MySQL的解夢系統:設計與實現

引言 夢境解析一直是人類心理學和文化研究的重要領域。隨著互聯網技術的發展,構建一個在線的解夢系統能夠幫助更多人理解自己夢境的含義。本文將詳細介紹如何使用PHP和MySQL構建一個功能完整的解夢系統,包括系統架構設計、數據庫模型、核心功能實現以及優化策略。 本文源碼下…

【桌面】【系統應用】Samba共享文件夾

目錄 場景一&#xff1a;銀河麒麟桌面與銀河麒麟桌面之間共享文件夾 環境準備 實現目標 操作步驟 &#xff08;一&#xff09;配置主機A共享文件夾 1、環境準備 2、在主機A創建共享文件夾 3、設置共享文件密碼 &#xff08;二&#xff09;主機B訪問主機A 場景二&…

OpenCV 圖形API(37)圖像濾波-----分離過濾器函數sepFilter()

操作系統&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 編程語言&#xff1a;C11 算法描述 應用一個可分離的線性濾波器到一個矩陣&#xff08;圖像&#xff09;。 該函數對矩陣應用一個可分離的線性濾波器。也就是說&#xff0c;首先&a…

webpack理解與使用

一、背景 webpack的最初目標是實現前端工程的模塊化&#xff0c;旨在更高效的管理和維護項目中的每一個資源。 最早的時候&#xff0c;我們通過文件劃分的方式實現模塊化&#xff0c;也就是將每個功能及其相關狀態數據都放在一個JS文件中&#xff0c;約定每個文件就是一個獨立…

rac環境下,增加一個控制文件controlfile

先關閉節點二&#xff0c;在節點一上操作 1、查看控制文件個數和路徑 SQL> show parameter control 2、備份參數文件 SQL> create pfile/home/oracle/orcl.pfile20250417 from spfile; 3、修改控制文件參數 SQL> alter system set contr…

git安裝(windows)

通過網盤分享的文件&#xff1a;資料(1) 鏈接: https://pan.baidu.com/s/1MAenYzcQ436MlKbIYQidoQ 提取碼: evu6 點擊next 可修改安裝路徑 默認就行 一般從命令行調用&#xff0c;所以不用創建。 用vscode&#xff0c;所以這么選擇。

Spring Boot整合難點?AI一鍵生成全流程解決方案

在當今的軟件開發領域&#xff0c;Spring Boot 憑借其簡化開發流程、快速搭建項目的優勢&#xff0c;成為了眾多開發者的首選框架。然而&#xff0c;Spring Boot 的整合過程并非一帆風順&#xff0c;常常會遇到各種難點。而飛算 JavaAI 的出現&#xff0c;為解決這些問題提供了…

Python批量處理PDF圖片詳解(插入、壓縮、提取、替換、分頁、旋轉、刪除)

目錄 一、概述 二、 使用工具 三、Python 在 PDF 中插入圖片 3.1 插入圖片到現有PDF 3.2 插入圖片到新建PDF 3.3 批量插入多張圖片到PDF 四、Python 提取 PDF 圖片及其元數據 五、Python 替換 PDF 圖片 5.1 使用圖片替換圖片 5.2 使用文字替換圖片 六、Python 實現 …