【NLP入門系列四】評論文本分類入門案例

在這里插入圖片描述

  • 🍨 本文為🔗365天深度學習訓練營 中的學習記錄博客
  • 🍖 原作者:K同學啊

博主簡介:努力學習的22級本科生一枚 🌟?;探索AI算法,C++,go語言的世界;在迷茫中尋找光芒?🌸
博客主頁:羊小豬~~-CSDN博客
內容簡介:這一篇是NLP的入門項目,以AG_NEW新聞數據為例。
🌸箴言🌸:去尋找理想的“天空“”之城
上一篇內容:【NLP入門系列三】NLP文本嵌入(以Embedding和EmbeddingBag為例)-CSDN博客
?💁??💁??💁??💁?: 如果在conda安裝環境,由于nlp的核心包是torchtext,所以如果把握不好就重新創建一虛擬環境(小編的“難忘”經歷)

文章目錄

    • 1、準備
      • 數據加載
      • 構建詞表
    • 2、生成數據批次和迭代器
    • 3、定義與模型
      • 模型定義
      • 創建模型
    • 4、創建訓練和評估函數
      • 訓練函數
      • 評估函數
      • 創建超參數
    • 5、模型訓練
    • 6、結果展示
    • 7、預測

🤔 思路

在這里插入圖片描述

1、準備

AG News 數據集(也叫 AG’s Corpus or AG News Dataset),這是一個廣泛用于自然語言處理(NLP)任務中的文本分類數據集


基本信息:

  • 全稱:AG News
  • 來源:來源于 AG’s corpus,由 A. Godin 在 2005 年構建。
  • 用途:主要用于短文本多類別分類任務
  • 語言:英文
  • 類別數:4 類新聞主題
  • 訓練樣本數:120,000 條
  • 測試樣本數:7,600 條

類別標簽(共 4 類)

標簽含義
1World (世界)
2Sports (體育)
3Business (商業)
4Science and Technology (科技)

數據加載

import torch
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader 
import torchtext 
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator# 檢查設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')
# 加載本地數據
train_df = pd.read_csv("./data/train.csv")
test_df = pd.read_csv("./data/test.csv")# 合并標題和描述數據
train_df["text"] = train_df["Title"] + " " + train_df["Description"]
test_df["text"] = test_df["Title"] + " " + test_df["Description"]# 查看數據格式
train_df
Class IndexTitleDescriptiontext
03Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...Wall St. Bears Claw Back Into the Black (Reute...
13Carlyle Looks Toward Commercial Aerospace (Reu...Reuters - Private investment firm Carlyle Grou...Carlyle Looks Toward Commercial Aerospace (Reu...
23Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\ab...Oil and Economy Cloud Stocks' Outlook (Reuters...
33Iraq Halts Oil Exports from Main Southern Pipe...Reuters - Authorities have halted oil export\f...Iraq Halts Oil Exports from Main Southern Pipe...
43Oil prices soar to all-time record, posing new...AFP - Tearaway world oil prices, toppling reco...Oil prices soar to all-time record, posing new...
...............
1199951Pakistan's Musharraf Says Won't Quit as Army C...KARACHI (Reuters) - Pakistani President Perve...Pakistan's Musharraf Says Won't Quit as Army C...
1199962Renteria signing a top-shelf dealRed Sox general manager Theo Epstein acknowled...Renteria signing a top-shelf deal Red Sox gene...
1199972Saban not going to Dolphins yetThe Miami Dolphins will put their courtship of...Saban not going to Dolphins yet The Miami Dolp...
1199982Today's NFL gamesPITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ...Today's NFL games PITTSBURGH at NY GIANTS Time...
1199992Nets get Carter from RaptorsINDIANAPOLIS -- All-Star Vince Carter was trad...Nets get Carter from Raptors INDIANAPOLIS -- A...

120000 rows × 4 columns

構建詞表

# 定義 Dataset
class AGNewsDataset(Dataset):def __init__(self, dataframe):self.labels = dataframe['Class Index'].tolist()  # 列表數據self.texts = dataframe['text'].tolist()def __len__(self):return len(self.labels)def __getitem__(self, idx):return self.labels[idx], self.texts[idx]# 加載數據
train_dataset = AGNewsDataset(train_df)
test_dataset = AGNewsDataset(test_df)# 構建詞表
tokenizer = get_tokenizer("basic_english")  # 英文數據,設置英文分詞def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)  # 構建詞表# 構建詞表,設置索引
vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])print("Vocab size:", len(vocab))
Vocab size: 95804
# 查看這些單詞所在詞典的索引
vocab(['here', 'is', 'an', 'example'])  
[475, 21, 30, 5297]
'''  
標簽,原始是字符串類型,現在要轉換成 數字 類型
文本數字化,需要一個函數進行轉換(vocab)
'''
text_pipline = lambda x : vocab(tokenizer(x))  # 先分詞。在數字化
label_pipline = lambda x : int(x) - 1   # 標簽轉化為數字# 舉例
text_pipline('here is the an example')
[475, 21, 2, 30, 5297]

2、生成數據批次和迭代器

# 采用embeddingbag嵌入方式,故需要構建數據,包括長度、標簽、偏移量
''' 
數據格式:長度(~, 1)
標簽:一維
偏移量:一維
'''
def collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_label, _text) in batch:# 標簽列表,注意字符串轉換成數字label_list.append(label_pipline(_label))# 文本列表,注意要轉入tensro數據temp_text = torch.tensor(text_pipline(_text), dtype=torch.int64)text_list.append(temp_text)# 偏移量offsets.append(temp_text.size(0))# 全部轉變成tensor變量label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)return label_list.to(device), text_list.to(device), offsets.to(device)# 數據加載
batch_size = 16
train_dl = DataLoader(train_dataset,batch_size=batch_size,shuffle=False,collate_fn=collate_batch
)test_dl = DataLoader(test_dataset,batch_size=batch_size,shuffle=False,collate_fn=collate_batch
)

3、定義與模型

模型定義

class TextModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super().__init__()self.embeddingBag = nn.EmbeddingBag(vocab_size,  # 詞典大小embed_dim,   # 嵌入維度sparse=False)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()# 初始化權重def init_weights(self):initrange = 0.5self.embeddingBag.weight.data.uniform_(-initrange, initrange)  # 初始化權重范圍self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()  # 偏置置為0def forward(self, text, offsets):embedding = self.embeddingBag(text, offsets)return self.fc(embedding)

創建模型

# 查看數據類別
train_df.groupby('Class Index').count()
TitleDescriptiontext
Class Index
1300003000030000
2300003000030000
3300003000030000
4300003000030000
class_num = 4
vocab_len = len(vocab)
embed_dim = 64  # 嵌入到64維度中
model = TextModel(vocab_size=vocab_len, embed_dim=embed_dim, num_class=class_num).to(device=device)

4、創建訓練和評估函數

訓練函數

def train(model, dataset, optimizer, loss_fn):size = len(dataset.dataset)num_batch = len(dataset)train_acc = 0train_loss = 0for _, (label, text, offset) in enumerate(dataset):label, text, offset = label.to(device), text.to(device), offset.to(device)predict_label = model(text, offset)loss = loss_fn(predict_label, label)# 求導與反向傳播optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (predict_label.argmax(1) == label).sum().item()train_loss += loss.item()train_acc /= size train_loss /= num_batchreturn train_acc, train_loss

評估函數

def test(model, dataset, loss_fn):size = len(dataset.dataset)batch_size = len(dataset)test_acc, test_loss = 0, 0with torch.no_grad():for _, (label, text, offset) in enumerate(dataset):label, text, offset = label.to(device), text.to(device), offset.to(device)predict = model(text, offset)loss = loss_fn(predict, label) test_acc += (predict.argmax(1) == label).sum().item()test_loss += loss.item()test_acc /= size test_loss /= batch_sizereturn test_acc, test_loss

創建超參數

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.01)  # 動態調整學習率

5、模型訓練

import copyepochs = 10train_acc, train_loss, test_acc, test_loss = [], [], [], []best_acc = 0for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(model, train_dl, optimizer, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)model.eval()epoch_test_acc, epoch_test_loss = test(model, test_dl, loss_fn)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)if best_acc is not None and epoch_test_acc > best_acc:# 動態調整學習率scheduler.step()best_acc = epoch_test_accbest_model = copy.deepcopy(model)  # 保存模型# 當前學習率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,  epoch_test_acc*100, epoch_test_loss, lr))# 保存最佳模型到文件
path = './best_model.pth'
torch.save(best_model.state_dict(), path) # 保存模型參數
Epoch: 1, Train_acc:79.9%, Train_loss:0.562, Test_acc:86.9%, Test_loss:0.392, Lr:5.00E-01
Epoch: 2, Train_acc:89.7%, Train_loss:0.313, Test_acc:88.9%, Test_loss:0.346, Lr:5.00E-01
Epoch: 3, Train_acc:91.2%, Train_loss:0.269, Test_acc:89.6%, Test_loss:0.329, Lr:5.00E-01
Epoch: 4, Train_acc:92.0%, Train_loss:0.243, Test_acc:89.8%, Test_loss:0.319, Lr:5.00E-01
Epoch: 5, Train_acc:92.6%, Train_loss:0.224, Test_acc:90.2%, Test_loss:0.315, Lr:5.00E-03
Epoch: 6, Train_acc:93.3%, Train_loss:0.207, Test_acc:90.6%, Test_loss:0.297, Lr:5.00E-03
Epoch: 7, Train_acc:93.4%, Train_loss:0.204, Test_acc:90.7%, Test_loss:0.295, Lr:5.00E-03
Epoch: 8, Train_acc:93.4%, Train_loss:0.203, Test_acc:90.7%, Test_loss:0.294, Lr:5.00E-03
Epoch: 9, Train_acc:93.4%, Train_loss:0.202, Test_acc:90.8%, Test_loss:0.293, Lr:5.00E-03
Epoch:10, Train_acc:93.4%, Train_loss:0.201, Test_acc:90.7%, Test_loss:0.293, Lr:5.00E-03

6、結果展示

import matplotlib.pyplot as plt
#隱藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False      # 用來正常顯示負號
plt.rcParams['figure.dpi']         = 100        #分辨率epoch_length = range(epochs)plt.figure(figsize=(12, 3))plt.subplot(1, 2, 1)
plt.plot(epoch_length, train_acc, label='Train Accuaray')
plt.plot(epoch_length, test_acc, label='Test Accuaray')
plt.legend(loc='lower right')
plt.title('Accurary')plt.subplot(1, 2, 2)
plt.plot(epoch_length, train_loss, label='Train Loss')
plt.plot(epoch_length, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Loss')plt.show()

?
在這里插入圖片描述

?

7、預測

model.load_state_dict(torch.load("./best_model.pth"))
model.eval() # 模型評估# 測試句子
test_sentence = "This is a news about Technology"# 轉換為 token
token_ids = vocab(tokenizer(test_sentence))   # 切割分詞--> 詞典序列
text = torch.tensor(token_ids, dtype=torch.long).to(device)  # 轉化為tensor
offsets = torch.tensor([0], dtype=torch.long).to(device)# 測試,注意:不需要反向求導
with torch.no_grad():output = model(text, offsets)predicted_label = output.argmax(1).item()# 輸出結果
class_names = ["World", "Sports", "Business", "Science and Technology"]
print(f"預測類別: {class_names[predicted_label]}")
預測類別: Science and Technology

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/87528.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/87528.shtml
英文地址,請注明出處:http://en.pswp.cn/web/87528.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Ubuntu安裝ClickHouse

注&#xff1a;本文章的ubuntu的版本為&#xff1a;ubuntu-20.04.6-live-server-amd64。 Ubuntu&#xff08;在線版&#xff09; 更新軟件源 sudo apt-get update 安裝apt-transport-https 允許apt工具通過https協議下載軟件包。 sudo apt-get install apt-transport-htt…

C++26 下一代C++標準

C++26 將是繼 C++23 之后的下一個 C++ 標準。這個新標準對 C++ 進行了重大改進,很可能像 C++98、C++11 或 C++20 那樣具有劃時代的意義。 一:C++標準回顧 C++ 已經有 40 多年的歷史了。過去這些年里發生了什么?這里給出一個簡化版的答案,直到即將到來的 C++26。 1. C++9…

【MySQL】十六,MySQL窗口函數

在 MySQL 8.0 及以后版本中&#xff0c;窗口函數&#xff08;Window Functions&#xff09;為數據分析和處理提供了強大的工具。窗口函數允許在查詢結果集上執行計算&#xff0c;而不必使用子查詢或連接&#xff0c;這使得某些類型的計算更加高效和簡潔。 語法結構 function_…

微型氣象儀在城市環境的應用

微型氣象儀憑借其體積小、成本低、部署靈活、數據實時性強等特點&#xff0c;在城市環境中得到廣泛應用&#xff0c;能夠為城市規劃、環境管理、公共安全、居民生活等領域提供精細化氣象數據支持。一、核心應用場景1. 城市微氣候監測與優化熱島效應研究場景&#xff1a;在城市不…

【仿muduo庫實現并發服務器】eventloop模塊

仿muduo庫實現并發服務器一.eventloop模塊1.成員變量std::thread::id _thread_id;//線程IDPoller _poll;int _event_fd;std::vector<Function<Function>> _task;TimerWheel _timer_wheel2.EventLoop構造3.針對eventfd的操作4.針對poller的操作5.針對threadID的操作…

Redis 加鎖、解鎖

Redis 加鎖和解鎖的應用 上代碼 應用調用示例 RedisLockEntity lockEntityYlb RedisLockEntity.builder().lockKey(TradeConstants.HP_APP_AMOUNT_LOCK_PREFIX appUser.getAccount()).value(orderId).build();boolean isLockedYlb false;try {if (redisLock.tryLock(lockE…

在 Windows 上為 WSL 增加 root 賬號密碼并通過 Shell 工具連接

1. 為 WSL 設置 root 用戶密碼 在 Windows 上使用 WSL&#xff08;Windows Subsystem for Linux&#xff09;時&#xff0c;默認情況下并沒有啟用 root 賬號的密碼。為了通過 SSH 或其他工具以 root 身份連接到 WSL&#xff0c;我們需要為 root 用戶設置密碼。 設置 root 密碼步…

2730、找到最長的半重復子字符穿

題目&#xff1a; 解答&#xff1a; 窗口為[left&#xff0c;right]&#xff0c;ans為窗口長度&#xff0c;same為子串長度&#xff0c;窗口滿足題設條件&#xff0c;即只含一個連續重復字符&#xff0c;則更新ans&#xff0c;否則從左邊開始一直彈出&#xff0c;直到滿足條件…

MCP Java SDK源碼分析

MCP Java SDK源碼分析 一、引言 在當今人工智能飛速發展的時代&#xff0c;大型語言模型&#xff08;LLMs&#xff09;如GPT - 4、Claude等展現出了強大的語言理解和生成能力。然而&#xff0c;這些模型面臨著一個核心限制&#xff0c;即無法直接訪問外部世界的數據和工具。M…

[Linux]內核如何對信號進行捕捉

要理解Linux中內核如何對信號進行捕捉&#xff0c;我們需要很多前置知識的理解&#xff1a; 內核態和用戶態的區別CPU指令集權限內核態和用戶態之間的切換 由于文章的側重點不同&#xff0c;上面這些知識我會在這篇文章盡量詳細提及&#xff0c;更詳細內容還得請大家查看這篇…

設計模式-觀察者模式、命令模式

觀察者模式Observer&#xff08;觀察者&#xff09;—對象行為型模式定義&#xff1a;定義了一種一對多的依賴關系,讓多個觀察者對象同時監聽某一主題對象,在它的狀態發生變化時,會通知所有的觀察者.先將 Observer A B C 注冊到 Observable &#xff0c;那么當 Observable 狀態…

【Unity筆記01】基于單例模式的簡單UI框架

單例模式的UIManagerusing System.Collections; using System.Collections.Generic; using UnityEngine;public class UIManager {private static UIManager _instance;public Dictionary<string, string> pathDict;public Dictionary<string, GameObject> prefab…

深入解析 OPC UA:工業自動化與物聯網的關鍵技術

在當今快速發展的工業自動化和物聯網&#xff08;IoT&#xff09;領域&#xff0c;數據的無縫交換和集成變得至關重要。OPC UA&#xff08;Open Platform Communications Unified Architecture&#xff09;作為一種開放的、跨平臺的工業通信協議&#xff0c;正在成為這一領域的…

MCP 協議的未來發展趨勢與學習路徑

MCP 協議的未來發展趨勢 6.1 MCP 技術演進與更新 MCP 協議正在快速發展&#xff0c;不斷引入新的功能和改進。根據 2025 年 3 月 26 日發布的協議規范&#xff0c;MCP 的最新版本已經引入了多項重要更新&#xff1a; 1.HTTP Transport 正式轉正&#xff1a;引入 Streamable …

硬件嵌入式學習路線大總結(一):C語言與linux。內功心法——從入門到精通,徹底打通你的任督二脈!

嵌入式工程師學習路線大總結&#xff08;一&#xff09; 引言&#xff1a;C語言——嵌入式領域的“屠龍寶刀”&#xff01; 兄弟們&#xff0c;如果你想在嵌入式領域闖出一片天地&#xff0c;C語言就是你手里那把最鋒利的“屠龍寶刀”&#xff01;它不像Python那樣優雅&#xf…

MCP server資源網站去哪找?國內MCP服務合集平臺有哪些?

在人工智能飛速發展的今天&#xff0c;AI模型與外部世界的交互變得愈發重要。一個好的工具不僅能提升開發效率&#xff0c;還能激發更多的創意。今天&#xff0c;我要給大家介紹一個寶藏平臺——AIbase&#xff08;<https://mcp.aibase.cn/>&#xff09;&#xff0c;一個…

修改Spatial-MLLM項目,使其專注于無人機航拍視頻的空間理解

修改Spatial-MLLM項目&#xff0c;使其專注于無人機航拍視頻的空間理解。以下是修改方案和關鍵代碼實現&#xff1a; 修改思路 輸入處理&#xff1a;將原項目的視頻文本輸入改為單一無人機航拍視頻/圖像輸入問題生成&#xff1a;自動生成空間理解相關的問題&#xff08;無需用戶…

攻防世界-Reverse-insanity

知識點 1.ELF文件逆向 2.IDApro的使用 3.strings的使用 步驟 方法一&#xff1a;IDA 使用exeinfo打開&#xff0c;發現是32位ELF文件&#xff0c;然后用ida32打開。 找到main函數&#xff0c;然后F5反編譯&#xff0c;得到flag。 tip&#xff1a;該程序是根據隨機函數生成…

【openp2p】 學習1:P2PApp和優秀的go跨平臺項目

P2PApp下面給出一個基于 RESTful 風格的 P2PApp 管理方案示例,供二次開發或 API 對接參考。核心思路就是把每個 P2PApp 當成一個可創建、查詢、修改、啟動/停止、刪除的資源來管理。 一、P2PApp 資源模型 P2PApp:id: string # 唯一標識name: string # …

邊緣設備上部署模型的限制之一——顯存占用:模型的參數量只是冰山一角

邊緣設備上部署模型的限制之一——顯存占用&#xff1a;模型的參數量只是冰山一角 在邊緣設備上部署深度學習模型已成為趨勢&#xff0c;但資源限制是其核心挑戰之一。其中&#xff0c;顯存&#xff08;或更廣義的內存&#xff09;占用是開發者們必須仔細考量的重要因素。許多…