- 🍨 本文為🔗365天深度學習訓練營 中的學習記錄博客
- 🍖 原作者:K同學啊
博主簡介:努力學習的22級本科生一枚 🌟?;探索AI算法,C++,go語言的世界;在迷茫中尋找光芒?🌸
博客主頁:羊小豬~~-CSDN博客
內容簡介:這一篇是NLP的入門項目,以AG_NEW新聞數據為例。
🌸箴言🌸:去尋找理想的“天空“”之城
上一篇內容:【NLP入門系列三】NLP文本嵌入(以Embedding和EmbeddingBag為例)-CSDN博客
?💁??💁??💁??💁?: 如果在conda安裝環境,由于nlp的核心包是torchtext,所以如果把握不好就重新創建一虛擬環境(小編的“難忘”經歷)
文章目錄
- 1、準備
- 數據加載
- 構建詞表
- 2、生成數據批次和迭代器
- 3、定義與模型
- 模型定義
- 創建模型
- 4、創建訓練和評估函數
- 訓練函數
- 評估函數
- 創建超參數
- 5、模型訓練
- 6、結果展示
- 7、預測
🤔 思路
1、準備
AG News 數據集(也叫 AG’s Corpus or AG News Dataset),這是一個廣泛用于自然語言處理(NLP)任務中的文本分類數據集。
基本信息:
- 全稱:AG News
- 來源:來源于 AG’s corpus,由 A. Godin 在 2005 年構建。
- 用途:主要用于短文本多類別分類任務
- 語言:英文
- 類別數:4 類新聞主題
- 訓練樣本數:120,000 條
- 測試樣本數:7,600 條
類別標簽(共 4 類)
標簽 | 含義 |
---|---|
1 | World (世界) |
2 | Sports (體育) |
3 | Business (商業) |
4 | Science and Technology (科技) |
數據加載
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchtext
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator# 檢查設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')
# 加載本地數據
train_df = pd.read_csv("./data/train.csv")
test_df = pd.read_csv("./data/test.csv")# 合并標題和描述數據
train_df["text"] = train_df["Title"] + " " + train_df["Description"]
test_df["text"] = test_df["Title"] + " " + test_df["Description"]# 查看數據格式
train_df
Class Index | Title | Description | text | |
---|---|---|---|---|
0 | 3 | Wall St. Bears Claw Back Into the Black (Reuters) | Reuters - Short-sellers, Wall Street's dwindli... | Wall St. Bears Claw Back Into the Black (Reute... |
1 | 3 | Carlyle Looks Toward Commercial Aerospace (Reu... | Reuters - Private investment firm Carlyle Grou... | Carlyle Looks Toward Commercial Aerospace (Reu... |
2 | 3 | Oil and Economy Cloud Stocks' Outlook (Reuters) | Reuters - Soaring crude prices plus worries\ab... | Oil and Economy Cloud Stocks' Outlook (Reuters... |
3 | 3 | Iraq Halts Oil Exports from Main Southern Pipe... | Reuters - Authorities have halted oil export\f... | Iraq Halts Oil Exports from Main Southern Pipe... |
4 | 3 | Oil prices soar to all-time record, posing new... | AFP - Tearaway world oil prices, toppling reco... | Oil prices soar to all-time record, posing new... |
... | ... | ... | ... | ... |
119995 | 1 | Pakistan's Musharraf Says Won't Quit as Army C... | KARACHI (Reuters) - Pakistani President Perve... | Pakistan's Musharraf Says Won't Quit as Army C... |
119996 | 2 | Renteria signing a top-shelf deal | Red Sox general manager Theo Epstein acknowled... | Renteria signing a top-shelf deal Red Sox gene... |
119997 | 2 | Saban not going to Dolphins yet | The Miami Dolphins will put their courtship of... | Saban not going to Dolphins yet The Miami Dolp... |
119998 | 2 | Today's NFL games | PITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ... | Today's NFL games PITTSBURGH at NY GIANTS Time... |
119999 | 2 | Nets get Carter from Raptors | INDIANAPOLIS -- All-Star Vince Carter was trad... | Nets get Carter from Raptors INDIANAPOLIS -- A... |
120000 rows × 4 columns
構建詞表
# 定義 Dataset
class AGNewsDataset(Dataset):def __init__(self, dataframe):self.labels = dataframe['Class Index'].tolist() # 列表數據self.texts = dataframe['text'].tolist()def __len__(self):return len(self.labels)def __getitem__(self, idx):return self.labels[idx], self.texts[idx]# 加載數據
train_dataset = AGNewsDataset(train_df)
test_dataset = AGNewsDataset(test_df)# 構建詞表
tokenizer = get_tokenizer("basic_english") # 英文數據,設置英文分詞def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text) # 構建詞表# 構建詞表,設置索引
vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])print("Vocab size:", len(vocab))
Vocab size: 95804
# 查看這些單詞所在詞典的索引
vocab(['here', 'is', 'an', 'example'])
[475, 21, 30, 5297]
'''
標簽,原始是字符串類型,現在要轉換成 數字 類型
文本數字化,需要一個函數進行轉換(vocab)
'''
text_pipline = lambda x : vocab(tokenizer(x)) # 先分詞。在數字化
label_pipline = lambda x : int(x) - 1 # 標簽轉化為數字# 舉例
text_pipline('here is the an example')
[475, 21, 2, 30, 5297]
2、生成數據批次和迭代器
# 采用embeddingbag嵌入方式,故需要構建數據,包括長度、標簽、偏移量
'''
數據格式:長度(~, 1)
標簽:一維
偏移量:一維
'''
def collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_label, _text) in batch:# 標簽列表,注意字符串轉換成數字label_list.append(label_pipline(_label))# 文本列表,注意要轉入tensro數據temp_text = torch.tensor(text_pipline(_text), dtype=torch.int64)text_list.append(temp_text)# 偏移量offsets.append(temp_text.size(0))# 全部轉變成tensor變量label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)return label_list.to(device), text_list.to(device), offsets.to(device)# 數據加載
batch_size = 16
train_dl = DataLoader(train_dataset,batch_size=batch_size,shuffle=False,collate_fn=collate_batch
)test_dl = DataLoader(test_dataset,batch_size=batch_size,shuffle=False,collate_fn=collate_batch
)
3、定義與模型
模型定義
class TextModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super().__init__()self.embeddingBag = nn.EmbeddingBag(vocab_size, # 詞典大小embed_dim, # 嵌入維度sparse=False)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()# 初始化權重def init_weights(self):initrange = 0.5self.embeddingBag.weight.data.uniform_(-initrange, initrange) # 初始化權重范圍self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_() # 偏置置為0def forward(self, text, offsets):embedding = self.embeddingBag(text, offsets)return self.fc(embedding)
創建模型
# 查看數據類別
train_df.groupby('Class Index').count()
Title | Description | text | |
---|---|---|---|
Class Index | |||
1 | 30000 | 30000 | 30000 |
2 | 30000 | 30000 | 30000 |
3 | 30000 | 30000 | 30000 |
4 | 30000 | 30000 | 30000 |
class_num = 4
vocab_len = len(vocab)
embed_dim = 64 # 嵌入到64維度中
model = TextModel(vocab_size=vocab_len, embed_dim=embed_dim, num_class=class_num).to(device=device)
4、創建訓練和評估函數
訓練函數
def train(model, dataset, optimizer, loss_fn):size = len(dataset.dataset)num_batch = len(dataset)train_acc = 0train_loss = 0for _, (label, text, offset) in enumerate(dataset):label, text, offset = label.to(device), text.to(device), offset.to(device)predict_label = model(text, offset)loss = loss_fn(predict_label, label)# 求導與反向傳播optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (predict_label.argmax(1) == label).sum().item()train_loss += loss.item()train_acc /= size train_loss /= num_batchreturn train_acc, train_loss
評估函數
def test(model, dataset, loss_fn):size = len(dataset.dataset)batch_size = len(dataset)test_acc, test_loss = 0, 0with torch.no_grad():for _, (label, text, offset) in enumerate(dataset):label, text, offset = label.to(device), text.to(device), offset.to(device)predict = model(text, offset)loss = loss_fn(predict, label) test_acc += (predict.argmax(1) == label).sum().item()test_loss += loss.item()test_acc /= size test_loss /= batch_sizereturn test_acc, test_loss
創建超參數
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.01) # 動態調整學習率
5、模型訓練
import copyepochs = 10train_acc, train_loss, test_acc, test_loss = [], [], [], []best_acc = 0for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(model, train_dl, optimizer, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)model.eval()epoch_test_acc, epoch_test_loss = test(model, test_dl, loss_fn)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)if best_acc is not None and epoch_test_acc > best_acc:# 動態調整學習率scheduler.step()best_acc = epoch_test_accbest_model = copy.deepcopy(model) # 保存模型# 當前學習率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))# 保存最佳模型到文件
path = './best_model.pth'
torch.save(best_model.state_dict(), path) # 保存模型參數
Epoch: 1, Train_acc:79.9%, Train_loss:0.562, Test_acc:86.9%, Test_loss:0.392, Lr:5.00E-01
Epoch: 2, Train_acc:89.7%, Train_loss:0.313, Test_acc:88.9%, Test_loss:0.346, Lr:5.00E-01
Epoch: 3, Train_acc:91.2%, Train_loss:0.269, Test_acc:89.6%, Test_loss:0.329, Lr:5.00E-01
Epoch: 4, Train_acc:92.0%, Train_loss:0.243, Test_acc:89.8%, Test_loss:0.319, Lr:5.00E-01
Epoch: 5, Train_acc:92.6%, Train_loss:0.224, Test_acc:90.2%, Test_loss:0.315, Lr:5.00E-03
Epoch: 6, Train_acc:93.3%, Train_loss:0.207, Test_acc:90.6%, Test_loss:0.297, Lr:5.00E-03
Epoch: 7, Train_acc:93.4%, Train_loss:0.204, Test_acc:90.7%, Test_loss:0.295, Lr:5.00E-03
Epoch: 8, Train_acc:93.4%, Train_loss:0.203, Test_acc:90.7%, Test_loss:0.294, Lr:5.00E-03
Epoch: 9, Train_acc:93.4%, Train_loss:0.202, Test_acc:90.8%, Test_loss:0.293, Lr:5.00E-03
Epoch:10, Train_acc:93.4%, Train_loss:0.201, Test_acc:90.7%, Test_loss:0.293, Lr:5.00E-03
6、結果展示
import matplotlib.pyplot as plt
#隱藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False # 用來正常顯示負號
plt.rcParams['figure.dpi'] = 100 #分辨率epoch_length = range(epochs)plt.figure(figsize=(12, 3))plt.subplot(1, 2, 1)
plt.plot(epoch_length, train_acc, label='Train Accuaray')
plt.plot(epoch_length, test_acc, label='Test Accuaray')
plt.legend(loc='lower right')
plt.title('Accurary')plt.subplot(1, 2, 2)
plt.plot(epoch_length, train_loss, label='Train Loss')
plt.plot(epoch_length, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Loss')plt.show()
?
?
7、預測
model.load_state_dict(torch.load("./best_model.pth"))
model.eval() # 模型評估# 測試句子
test_sentence = "This is a news about Technology"# 轉換為 token
token_ids = vocab(tokenizer(test_sentence)) # 切割分詞--> 詞典序列
text = torch.tensor(token_ids, dtype=torch.long).to(device) # 轉化為tensor
offsets = torch.tensor([0], dtype=torch.long).to(device)# 測試,注意:不需要反向求導
with torch.no_grad():output = model(text, offsets)predicted_label = output.argmax(1).item()# 輸出結果
class_names = ["World", "Sports", "Business", "Science and Technology"]
print(f"預測類別: {class_names[predicted_label]}")
預測類別: Science and Technology