目錄
技術細節
1、下載模型
?2、config文件
3、BERT 文本分類數據預處理流程
4、對輸入文本進行分類
?5、計算模型的分類性能指標
6、模型訓練
?7、基于BERT的文本分類預測接口
問題總結
技術細節
1、下載模型
文件名稱--a0_download_model.py
使用?ModelScope?庫從模型倉庫下載?BERT-base-Chinese?預訓練模型,并將其保存到本地指定目錄。
# 模型下載
from modelscope import snapshot_downloadmodel_dir = snapshot_download('google-bert/bert-base-chinese', local_dir=r"D:\src\bert-base-chinese")
?2、config文件
數據加載與保存的路徑
# 根目錄self.root_path = 'E:/PythonLearning/full_mask_project2/'# 原始數據路徑self.train_datapath = self.root_path + '01-data/train.txt'self.test_datapath = self.root_path + '01-data/test.txt'self.dev_datapath = self.root_path + '01-data/dev.txt'# 類別文檔self.class_path = self.root_path + "01-data/class.txt"# 類別名列表self.class_list = [line.strip() for line in open(self.class_path, encoding="utf-8")] # 類別名單# 模型訓練保存路徑self.model_save_path = self.root_path + "dm_03_bert/save_models/bert_classifer_model.pt" # 模型訓練結果保存路徑
加載預訓練Bert模型以及其分詞器和配置文件
self.bert_path = r"E:\PythonLearning\full_mark_Project\dm04_Bert\bert-base-chinese" # 預訓練BERT模型的路徑self.bert_model = BertModel.from_pretrained(self.bert_path) # 加載預訓練BERT模型self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) # BERT模型的分詞器self.bert_config = BertConfig.from_pretrained(self.bert_path) # BERT模型的配置
設置模型參數
# todo 參數self.num_classes = len(self.class_list) # 類別數self.num_epochs = 1 # epoch數self.batch_size = 64 # mini-batch大小self.pad_size = 32 # 每句話處理成的長度(短填長切)self.learning_rate = 5e-5 # 學習率
完整config代碼
import torch
import datetime
from transformers import BertModel, BertTokenizer, BertConfig# 獲取當前日期字符串
current_date = datetime.datetime.now().date().strftime("%Y%m%d")# 配置類
class Config(object):def __init__(self):"""配置類,包含模型和訓練所需的各種參數。"""self.model_name = "bert" # 模型名稱# todo 路徑# 根目錄self.root_path = 'E:/PythonLearning/full_mask_project2/'# 原始數據路徑self.train_datapath = self.root_path + '01-data/train.txt'self.test_datapath = self.root_path + '01-data/test.txt'self.dev_datapath = self.root_path + '01-data/dev.txt'# 類別文檔self.class_path = self.root_path + "01-data/class.txt"# 類別名列表self.class_list = [line.strip() for line in open(self.class_path, encoding="utf-8")] # 類別名單# 模型訓練保存路徑self.model_save_path = self.root_path + "dm_03_bert/save_models/bert_classifer_model.pt" # 模型訓練結果保存路徑# 模型訓練+預測的時候 訓練設備,如果GPU可用,則為cuda,否則為cpuself.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.bert_path = r"E:\PythonLearning\full_mark_Project\dm04_Bert\bert-base-chinese" # 預訓練BERT模型的路徑self.bert_model = BertModel.from_pretrained(self.bert_path) # 加載預訓練BERT模型self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) # BERT模型的分詞器self.bert_config = BertConfig.from_pretrained(self.bert_path) # BERT模型的配置# todo 參數self.num_classes = len(self.class_list) # 類別數self.num_epochs = 1 # epoch數self.batch_size = 64 # mini-batch大小self.pad_size = 32 # 每句話處理成的長度(短填長切)self.learning_rate = 5e-5 # 學習率# TODO 量化模型存放地址# 注意: 量化的時候模型需要的設備首選是cpuself.bert_model_quantization_model_path = self.root_path + "dm_03_bert/save_models/bert_classifer_quantization_model.pt" # 模型訓練結果保存路徑if __name__ == '__main__':# 測試conf = Config()print(conf.device)print(conf.class_list)print(conf.tokenizer)input_size = conf.tokenizer.convert_tokens_to_ids(["你", "好", "中", "人"])print(input_size)print(conf.bert_model)print(conf.bert_config)
3、BERT 文本分類數據預處理流程
數據加載
def load_raw_data(file_path):"""從指定文件中加載原始數據。處理文本文件,返回(文本, 標簽類別索引)元組列表參數:file_path: 原始文本文件路徑返回:list: 包含(文本, 標簽類別索引)的元組列表,類別為int類型[('體驗2D巔峰 倚天屠龍記十大創新概覽', 8), ('60年鐵樹開花形狀似玉米芯(組圖)', 5)]"""result = []# 打印指定文件with open(file_path, 'r', encoding='utf-8') as f:# 使用tqdm包裝文件讀取迭代器,以便顯示加載數據的進度條for line in tqdm(f, desc=f"加載原始數據{file_path}"):# 移除行兩端的空白字符line = line.strip()# 跳過空行if not line:continue# 將行分割成文本和標簽兩部分text, label = line.split("\t")# 將標簽轉為int作為類別label = int(label)# 將文本和轉換為整數的標簽作為元組添加到數據列表中result.append((text, label))# 返回處理后的列表return result
數據集構建
# todo 2.自定義數據集
class TextDataset(Dataset):# 初始化數據def __init__(self, data_list):self.data_list = data_list# 返回數據集長度def __len__(self):return len(self.data_list)# 根據樣本索引,返回對應的特征和標簽def __getitem__(self, idx):text, label = self.data_list[idx]return text, label
批處理(padding/collate)
每當 DataLoader 從 Dataset 中取出一批batch 的原始數據后, 就會調用 collate_fn 來對這個 batch 進行統一處理(如填充、轉換為張量等)。
batch
?是一個包含多個 (文本, 標簽) 元組的列表使用?
zip(*batch)
?將元組列表"轉置"為兩個元組:一個包含所有文本,一個包含所有標簽例如:
[("text1", 1), ("text2", 2)]
?→?("text1", "text2")
?和?(1, 2)
add_special_tokens=True
: 自動添加 [CLS] 和 [SEP] 等特殊token
padding='max_length'
: 將所有文本填充到固定長度?conf.pad_size
max_length=conf.pad_size
: 設置最大長度
truncation=True
: 如果文本超過最大長度則截斷
return_attention_mask=True
: 返回注意力掩碼
input_ids
: 文本轉換為的數字token ID序列
attention_mask
: 標記哪些位置是實際文本(1),哪些是填充部分(0)
def collate_fn(batch):"""對batch數據進行padding處理參數: batch: 包含(文本, 標簽)元組的batch數據返回: tuple: 包含處理后的input_ids, attention_mask和labels的元組"""# todo 使用zip()將一批batch數據中的(text, label)元組拆分成兩個獨立的元組# texts = [item[0] for item in batch]# labels = [item[1] for item in batch]texts, labels = zip(*batch)# 對文本進行paddingtext_tokens = conf.tokenizer.batch_encode_plus(texts,add_special_tokens=True, # 默認True,自動添加 [CLS] 和 [SEP]# padding=True,自動填充到數據中的最大長度 padding='max_length':填充到指定的固定長度padding='max_length',max_length=conf.pad_size, # 設定目標長度truncation=True, # 開啟截斷,防止超出模型限制return_attention_mask=True # 請求返回注意力掩碼,以區分輸入中的有效信息和填充信息)# 從文本令牌中提取輸入IDinput_ids = text_tokens['input_ids']# 從文本令牌中提取注意力掩碼attention_mask = text_tokens['attention_mask']# 將輸入的token ID列表轉換為張量input_ids = torch.tensor(input_ids)# 將注意力掩碼列表轉換為張量attention_mask = torch.tensor(attention_mask)# 將標簽列表轉換為張量labels = torch.tensor(labels)# 返回轉換后的張量元組return input_ids, attention_mask, labels
DataLoader 封裝
def build_dataloader():# 加載原始數據train_data_list = load_raw_data(conf.train_datapath)dev_data_list = load_raw_data(conf.dev_datapath)test_data_list = load_raw_data(conf.test_datapath)# 構建訓練集train_dataset = TextDataset(train_data_list)dev_dataset = TextDataset(dev_data_list)test_dataset = TextDataset(test_data_list)# 構建DataLoadertrain_dataloader = DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)dev_dataloader = DataLoader(dev_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)test_dataloader = DataLoader(test_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)return train_dataloader, dev_dataloader, test_dataloader
完整代碼
# 加載數據工具類
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, TextDataset
from transformers.utils import PaddingStrategy
from config import Config
import time# 加載配置
conf = Config()# todo 加載并處理原始數據
def load_raw_data(file_path):"""從指定文件中加載原始數據。處理文本文件,返回(文本, 標簽, 類別)元組列表參數:file_path: 文本文件路徑返回:list: 包含(文本, 標簽, 類別)的元組列表,類別為int類型"""result = []# 打印指定文件with open(file_path, 'r', encoding='utf-8') as f:# 使用tqdm包裝文件讀取迭代器,以便顯示加載數據的進度條for line in tqdm(f, desc="加載數據..."):# 移除行兩端的空白字符line = line.strip()# 跳過空行if not line:continue# 將行分割成文本和標簽兩部分text, label = line.split("\t")# 將標簽轉為int作為類別label = int(label)# 將文本和轉換為整數的標簽作為元組添加到數據列表中result.append((text, label))# 返回處理后的列表return result# todo 自定義數據集
class TextDataset(Dataset):# 初始化數據def __init__(self, data_list):self.data_list = data_list# 返回數據集長度def __len__(self):return len(self.data_list)# 根據樣本索引,返回對應的特征和標簽def __getitem__(self, idx):text, label = self.data_list[idx]return text, label# todo 批量處理數據
# 每當 DataLoader 從 Dataset 中取出一個 batch 的原始數據后,
# 就會調用 collate_fn 來對這個 batch 進行統一處理(如填充、轉換為張量等)。
def collate_fn(batch):"""對batch數據進行padding處理參數: batch: 包含(文本, 標簽)元組的batch數據返回: tuple: 包含處理后的input_ids, attention_mask和labels的元組"""# todo 使用zip()將一批batch數據中的(text, label)元組拆分成兩個獨立的元組# texts = [item[0] for item in batch]# labels = [item[1] for item in batch]texts, labels = zip(*batch)# 對文本進行paddingtext_tokens = conf.tokenizer.batch_encode_plus(texts,add_special_tokens=True, # 默認True,自動添加 [CLS] 和 [SEP]padding='max_length', # 固定長度max_length=conf.pad_size, # 設定目標長度truncation=True, # 開啟截斷,防止超出模型限制return_attention_mask=True # 請求返回注意力掩碼,以區分輸入中的有效信息和填充信息)# 從文本令牌中提取輸入IDinput_ids = text_tokens['input_ids']# 從文本令牌中提取注意力掩碼attention_mask = text_tokens['attention_mask']# 將輸入的token ID列表轉換為張量input_ids = torch.tensor(input_ids)# 將注意力掩碼列表轉換為張量attention_mask = torch.tensor(attention_mask)# 將標簽列表轉換為張量labels = torch.tensor(labels)# 返回轉換后的張量元組return input_ids, attention_mask, labels# todo 構建dataloader
def build_dataloader():# 加載原始數據train_data_list = load_raw_data(conf.train_datapath)dev_data_list = load_raw_data(conf.dev_datapath)test_data_list = load_raw_data(conf.test_datapath)# 構建訓練集train_dataset = TextDataset(train_data_list)dev_dataset = TextDataset(dev_data_list)test_dataset = TextDataset(test_data_list)# 構建DataLoadertrain_dataloader = DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)dev_dataloader = DataLoader(dev_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)test_dataloader = DataLoader(test_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)return train_dataloader, dev_dataloader, test_dataloaderif __name__ == '__main__':# 測試load_raw_data方法data_list = load_raw_data(conf.dev_datapath)print(data_list[:10])# 測試TextDataset類dataset = TextDataset(data_list)print(dataset[0])print(dataset[1])# 測試build_dataloader方法train_dataloader, dev_dataloader, test_dataloader = build_dataloader()print(len(train_dataloader))print(len(dev_dataloader))print(len(test_dataloader))# 測試collate_fn方法"""for i, batch in enumerate(train_dataloader)流程如下:1.DataLoader 從你的 Dataset 中取出一組索引;2.使用這些索引調用 Dataset.__getitem__ 獲取原始樣本;3.將這一組樣本組成一個 batch(通常是 (text, label) 元組的列表);4.自動調用你傳入的 collate_fn 函數來處理這個 batch 數據;5.返回處理后的 batch(如 input_ids, attention_mask, labels)供模型使用。"""for i, batch in enumerate(train_dataloader):print(len(batch))print(i)input_ids, attention_mask, labels = batch# print("input_ids: ", input_ids.tolist())print("input_ids.shape: ", input_ids.shape)# print("attention_mask: ", attention_mask.tolist())print("attention_mask.shape: ", attention_mask.shape)# print("labels: ", labels.tolist())print("labels.shape: ", labels.shape)break
4、對輸入文本進行分類
-
流程:
-
加載預訓練 BERT 模型(
BertModel
)作為特征提取器。 -
添加全連接層(
nn.Linear
)進行類別預測。 -
提供測試代碼,演示從文本輸入到預測結果的全過程。
-
outputs
?是一個包含多個輸出的元組,對于分類任務我們主要關注:
last_hidden_state
: 序列中每個 token 的隱藏狀態 (shape:?[batch_size, sequence_length, hidden_size]
)
pooler_output
: 經過池化的整個序列的表示 (shape:?[batch_size, hidden_size]
)
假設:
batch_size = 32
sequence_length = 128
?(經過 padding/truncation 后的長度)
hidden_size = 768
?(BERT-base 的隱藏層大小)
num_classes = 10
數據流過程:
輸入?
input_ids
:?[32, 128]
輸入?
attention_mask
:?[32, 128]
BERT 輸出?
pooler_output
:?[32, 768]
全連接層輸出?
logits
:?[32, 10]
import torch
import torch.nn as nn
from transformers import BertModel
from config import Config# 加載配置
config = Config()# 定義bert模型
class BertClassifier(nn.Module):def __init__(self):# 初始化父類類的構造函數super().__init__()# 下面的BertModel是從transformers庫中加載的預訓練模型# config.bert_path是預訓練模型的路徑self.bert = BertModel.from_pretrained(config.bert_path)# 定義全連接層(fc),用于分類任務# 輸入尺寸是Bert模型隱藏層的大小,即768(對于Base模型)# 輸出尺寸是類別數量,由config.num_classes指定self.fc = nn.Linear(config.bert_config.hidden_size, config.num_classes)def forward(self, input_ids, attention_mask):# 使用BERT模型處理輸入的token ID和注意力掩碼,獲取BERT模型的輸出# outputs是: _,pooledoutputs = self.bert(input_ids=input_ids, # 輸入的token IDattention_mask=attention_mask # 注意力掩碼用于區分有效token和填充token)# print(outputs) # 觀察結果# 通過全連接層對BERT模型的輸出進行分類logits = self.fc(outputs.pooler_output)# 返回分類的logits(未歸一化的預測分數)return logits# 測試以上模型
if __name__ == '__main__':# 測試model = BertClassifier()# 加載from transformers import BertTokenizertokenizer = BertTokenizer.from_pretrained(config.bert_path)# 示例文本texts = ["我喜歡你", "今天天氣真好"]# 編碼文本encoded_inputs = tokenizer(texts,# padding=True, # 所有的填充到文本最大長度padding="max_length", # 所有的填充到指定的max_length長度truncation=True, # 如果超出指定的max_length長度,則截斷max_length=10,return_tensors="pt" # 返回 pytorch 張量,"pt" 時,分詞器會將輸入文本轉換為模型可接受的格式)# 獲取 input_ids 和 attention_maskinput_ids = encoded_inputs["input_ids"]attention_mask = encoded_inputs["attention_mask"]print('input_ids:', input_ids)print('attention_mask:', attention_mask)print('======================================')# 預測logits = model(input_ids, attention_mask)print(logits) # 每一行對應一個樣本,每個數字表示該樣本屬于某一類別的“得分”(logit),沒有經過 softmax 歸一化。print('-------------------------------')# 獲取預測概率probs = torch.softmax(logits, dim=-1)print(probs) # 歸一化后該樣本屬于某類的概率(范圍在 0~1 之間),概率最高的就是預測結果print('-------------------------------')preds = torch.argmax(probs, dim=-1)print(preds) # 得到每個樣本的預測類別。表示兩個輸入文本被模型預測為類別 6(從 0 開始計數)。
?5、計算模型的分類性能指標
功能:傳入模型、數據、設備,返回分類報告
完整代碼:
import torch
from sklearn.metrics import classification_report, f1_score, accuracy_score, precision_score, recall_score
from tqdm import tqdmdef model2dev(model, data_loader, device):"""在驗證或測試集上評估 BERT 分類模型的性能。參數:model (nn.Module): BERT 分類模型。data_loader (DataLoader): 數據加載器(驗證或測試集)。device (str): 設備("cuda" 或 "cpu")。返回:tuple: (分類報告, F1 分數, 準確度, 精確度,召回率)- report: 分類報告(包含每個類別的精確度、召回率、F1 分數等)。- f1score: 微平均 F1 分數。- accuracy: 準確度。- precision: 微平均精確度- recall: 微平均召回率"""# todo 1. 設置模型為評估模式(禁用 dropout,并改變batch_norm行為)model.eval()# 2. 初始化列表,存儲預測結果和真實標簽all_preds, all_labels = [], []# 3. todo torch.no_grad()禁用梯度計算以提高效率并減少內存占用with torch.no_grad():# 4. 遍歷數據加載器,逐批次進行預測for i, batch in enumerate(tqdm(data_loader, desc="驗證集評估中...")):# 4.1 提取批次數據并移動到設備input_ids, attention_mask, labels = batchinput_ids = input_ids.to(device)attention_mask = attention_mask.to(device)labels = labels.to(device)# 4.2 前向傳播:模型預測outputs = model(input_ids, attention_mask=attention_mask)# 4.3 獲取預測結果(最大 logits分數 對應的類別)y_pred_list = torch.argmax(outputs, dim=1)# 4.4 存儲預測和真實標簽all_preds.extend(y_pred_list.cpu().tolist())all_labels.extend(labels.cpu().tolist())# 5. 計算分類報告、F1 分數、準確率,精確率,召回率report = classification_report(all_labels, all_preds)f1score = f1_score(all_labels, all_preds, average='macro')accuracy = accuracy_score(all_labels, all_preds)precision = precision_score(all_labels, all_preds, average='macro')recall = recall_score(all_labels, all_preds, average='macro')# 6. 返回評估結果return report, f1score, accuracy, precision, recall
6、模型訓練
口訣:15241
加載配置文件和參數(1)
# todo 加載配置對象,包含模型參數、路徑等
conf = Config()
# todo 導入數據處理工具類
from a1_dataloader_utils import build_dataloader
# todo 導入bert模型
from a2_bert_classifer_model import BertClassifier
準備數據(4)
# todo 1、準備數據train_dataloader, dev_dataloader, test_dataloader = build_dataloader()# todo 2、準備模型# 2.1初始化bert分類模型model = BertClassifier()# 2.2將模型移動到指定的設備model.to(conf.device)# todo 3.準備損失函數loss_fn = nn.CrossEntropyLoss()# todo 4.準備優化器optimizer = AdamW(model.parameters(), lr=conf.learning_rate)
模型訓練
外層遍歷輪次,內層遍歷批次
前向傳播:模型預測、計算損失
反向傳播:梯度清零、計算梯度、參數更新
# todo 5.開始訓練模型# 初始化F1分數,用于保存最好的模型best_f1 = 0.0# todo 5.1 外層循環遍歷每個訓練輪次# (每次需要設置訓練模式,累計損失,預存訓練集測試和真實標簽)for epoch in range(conf.num_epochs):# 設置模型為訓練模式model.train()# 初始化累計損失,初始化訓練集預測和真實標簽total_loss = 0.0train_preds, train_labels = [], []# todo 5.2 內層循環遍歷訓練DataLoader每個批次for i, batch in enumerate(tqdm(train_dataloader, desc="訓練集訓練中...")):# 提取批次數據并移動到設備input_ids, attention_mask, labels = batchinput_ids = input_ids.to(conf.device)attention_mask = attention_mask.to(conf.device)labels = labels.to(conf.device)# todo 前向傳播:模型預測logits = model(input_ids, attention_mask)# todo 計算損失loss = loss_fn(logits, labels)# 累計損失total_loss += loss.item()# todo 獲取預測結果(最大logits對應的類別)y_pred_list = torch.argmax(logits, dim=1)# todo 存儲預測和真實標簽,用于計算訓練集指標train_preds.extend(y_pred_list.cpu().tolist())train_labels.extend(labels.cpu().tolist())# todo 梯度清零optimizer.zero_grad()# todo 反向傳播:計算梯度loss.backward()# todo 參數更新:根據梯度更新模型參數optimizer.step()
驗證評估
# todo 每10個批次或一個輪次結束,計算訓練集指標if (i + 1) % 10 == 0 or i == len(train_dataloader) - 1:# 計算準確率和f1值acc = accuracy_score(train_labels, train_preds)f1 = f1_score(train_labels, train_preds, average='macro')# 獲取batch_count,并計算平均損失batch_count = i % 10 + 1avg_loss = total_loss / batch_count# todo 打印訓練信息print(f"\n輪次: {epoch + 1}, 批次: {i + 1}, 損失: {avg_loss:.4f}, acc準確率:{acc:.4f}, f1分數:{f1:.4f}")# todo 清空累計損失和預測和真實標簽total_loss = 0.0train_preds, train_labels = [], []# todo 每100個批次或一個輪次結束,計算驗證集指標,打印,保存模型if (i + 1) % 100 == 0 or i == len(train_dataloader) - 1:# 計算在測試集的評估報告,準確率,精確率,召回率,f1值report, f1score, accuracy, precision, recall = model2dev(model, dev_dataloader, conf.device)print("驗證集評估報告:\n", report)print(f"驗證集的f1: {f1score:.4f}, accuracy:{accuracy:.4f}, precision:{precision:.4f}, recall:{recall:.4f}")# todo 將模型再設置為訓練模式model.train()# todo 如果驗證F1分數優于歷史最佳,保存模型if f1score > best_f1:# 更新歷史最佳F1分數best_f1 = f1score# 保存模型torch.save(model.state_dict(), conf.model_save_path)print("保存模型成功, 當前f1分數:", best_f1)
完整代碼
import torch
import torch.nn as nn
from torch.optim import AdamW
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
from model2dev_utils import model2dev
from config import Config
from a1_dataloader_utils import build_dataloader
# 忽略的警告信息
import warningswarnings.filterwarnings("ignore")
# todo 導入bert模型
from a2_bert_classifer_model import BertClassifier# 加載配置對象,包含模型參數、路徑等
conf = Config()def model2train():"""訓練 BERT 分類模型并在驗證集上評估,保存最佳模型。參數:無顯式參數,所有配置通過全局 conf 對象獲取。返回:無返回值,訓練過程中保存最佳模型到指定路徑。"""# todo 1、準備數據train_dataloader, dev_dataloader, test_dataloader = build_dataloader()# todo 2、準備模型# 2.1初始化bert分類模型model = BertClassifier()# 2.2將模型移動到指定的設備model.to(conf.device)# todo 3.準備損失函數loss_fn = nn.CrossEntropyLoss()# todo 4.準備優化器optimizer = AdamW(model.parameters(), lr=conf.learning_rate)# todo 5.開始訓練模型# 初始化F1分數,用于保存最好的模型best_f1 = 0.0# todo 5.1 外層循環遍歷每個訓練輪次# (每次需要設置訓練模式,累計損失,預存訓練集測試和真實標簽)for epoch in range(conf.num_epochs):# 設置模型為訓練模式model.train()# 初始化累計損失,初始化訓練集預測和真實標簽total_loss = 0.0train_preds, train_labels = [], []# todo 5.2 內層循環遍歷訓練DataLoader每個批次for i, batch in enumerate(tqdm(train_dataloader, desc="訓練集訓練中...")):# 提取批次數據并移動到設備input_ids, attention_mask, labels = batchinput_ids = input_ids.to(conf.device)attention_mask = attention_mask.to(conf.device)labels = labels.to(conf.device)# todo 前向傳播:模型預測outputs = model(input_ids, attention_mask)# todo 計算損失loss = loss_fn(outputs, labels)# todo 累計損失total_loss += loss.item()# todo 獲取預測結果(最大logits對應的類別)y_pred_list = torch.argmax(outputs, dim=1)# todo 存儲預測和真實標簽,用于計算訓練集指標train_preds.extend(y_pred_list.cpu().tolist())train_labels.extend(labels.cpu().tolist())# todo 梯度清零optimizer.zero_grad()# todo 反向傳播:計算梯度loss.backward()# todo 參數更新:根據梯度更新模型參數optimizer.step()# todo 每10個批次或一個輪次結束,計算訓練集指標if (i + 1) % 10 == 0 or i == len(train_dataloader) - 1:# 計算準確率和f1值acc = accuracy_score(train_labels, train_preds)f1 = f1_score(train_labels, train_preds, average='macro')# 獲取batch_count,并計算平均損失batch_count = i % 10 + 1avg_loss = total_loss / batch_count# todo 打印訓練信息print(f"\n輪次: {epoch + 1}, 批次: {i + 1}, 損失: {avg_loss:.4f}, acc準確率:{acc:.4f}, f1分數:{f1:.4f}")# todo 清空累計損失和預測和真實標簽total_loss = 0.0train_preds, train_labels = [], []# todo 每100個批次或一個輪次結束,計算驗證集指標,打印,保存模型if (i + 1) % 100 == 0 or i == len(train_dataloader) - 1:# 計算在測試集的評估報告,準確率,精確率,召回率,f1值report, f1score, accuracy, precision, recall = model2dev(model, dev_dataloader, conf.device)print("驗證集評估報告:\n", report)print(f"驗證集f1: {f1score:.4f}, accuracy:{accuracy:.4f}, precision:{precision:.4f}, recall:{recall:.4f}")# 將模型設置為訓練模式model.train()# todo 如果驗證F1分數優于歷史最佳,保存模型if f1score > best_f1:# 更新歷史最佳F1分數best_f1 = f1score# 保存模型torch.save(model.state_dict(), conf.model_save_path)print("保存模型成功, 當前f1分數:", best_f1)if __name__ == '__main__':model2train()
?7、基于BERT的文本分類預測接口
-
加載預訓練好的BERT分類模型(
BertClassifier
)。 -
對輸入文本進行分詞、編碼、轉換為張量。
-
使用模型預測文本類別。
-
返回類別名稱(如?
"education"
)
import torch
from a2_bert_classifer_model import BertClassifier
from config import Config# 加載配置
conf = Config()# todo 準備模型
model = BertClassifier()
# todo 加載模型參數
model.load_state_dict(torch.load(conf.model_save_path))
# todo 添加模型到指定設備
model.to(conf.device)
# todo 設置模型為評估模式
model.eval()# TODO 定義predict_fun函數預測函數
def predict_fun(data_dict):"""根據用戶錄入數據,返回分類信息:param 參數 data_dict: {"text":"狀元心經:考前一周重點是回顧和整理"}:return: 返回 data_dict: {"text":"狀元心經:考前一周重點是回顧和整理", "pred_class":"education"}"""# todo 獲取文本text = data_dict['text']# todo 將文本轉為idtext_tokens = conf.tokenizer.batch_encode_plus([text],padding="max_length",max_length=conf.pad_size,pad_to_max_length=True)# todo 獲取input_ids和attention_maskinput_ids = text_tokens['input_ids']attention_mask = text_tokens['attention_mask']# todo 將input_ids和attention_mask轉為tensor, 并指定到設備input_ids = torch.tensor(input_ids).to(conf.device)attention_mask = torch.tensor(attention_mask).to(conf.device)# todo 設置不進行梯度計算(在該上下文中禁用梯度計算,提升推理速度并減少內存占用)with torch.no_grad():# 前向傳播(模型預測)output = model(input_ids, attention_mask)# 獲取預測類別索引張量output = torch.argmax(output, dim=1)# 獲取預測類別索引標量pred_idx = output.item()# 獲取類別名稱pred_class = conf.class_list[pred_idx]print(pred_class)# 將預測結果添加到data_dict中data_dict['pred_class'] = pred_class# 返回data_dictreturn data_dictif __name__ == '__main__':data_dict = {'text': '狀元心經:考前一周重點是回顧和整理'}print(predict_fun(data_dict))
問題總結
持續更新中......