Bert項目--新聞標題文本分類

目錄

技術細節

1、下載模型

?2、config文件

3、BERT 文本分類數據預處理流程

4、對輸入文本進行分類

?5、計算模型的分類性能指標

6、模型訓練

?7、基于BERT的文本分類預測接口

問題總結


技術細節

1、下載模型

文件名稱--a0_download_model.py

使用?ModelScope?庫從模型倉庫下載?BERT-base-Chinese?預訓練模型,并將其保存到本地指定目錄。

# 模型下載
from modelscope import snapshot_downloadmodel_dir = snapshot_download('google-bert/bert-base-chinese', local_dir=r"D:\src\bert-base-chinese")

?2、config文件

數據加載與保存的路徑

# 根目錄self.root_path = 'E:/PythonLearning/full_mask_project2/'# 原始數據路徑self.train_datapath = self.root_path + '01-data/train.txt'self.test_datapath = self.root_path + '01-data/test.txt'self.dev_datapath = self.root_path + '01-data/dev.txt'# 類別文檔self.class_path = self.root_path + "01-data/class.txt"# 類別名列表self.class_list = [line.strip() for line in open(self.class_path, encoding="utf-8")]  # 類別名單# 模型訓練保存路徑self.model_save_path = self.root_path + "dm_03_bert/save_models/bert_classifer_model.pt"  # 模型訓練結果保存路徑

加載預訓練Bert模型以及其分詞器和配置文件

        self.bert_path = r"E:\PythonLearning\full_mark_Project\dm04_Bert\bert-base-chinese"  # 預訓練BERT模型的路徑self.bert_model = BertModel.from_pretrained(self.bert_path)  # 加載預訓練BERT模型self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)  # BERT模型的分詞器self.bert_config = BertConfig.from_pretrained(self.bert_path)  # BERT模型的配置

設置模型參數

        # todo 參數self.num_classes = len(self.class_list)  # 類別數self.num_epochs = 1  # epoch數self.batch_size = 64  # mini-batch大小self.pad_size = 32  # 每句話處理成的長度(短填長切)self.learning_rate = 5e-5  # 學習率

完整config代碼

import torch
import datetime
from transformers import BertModel, BertTokenizer, BertConfig# 獲取當前日期字符串
current_date = datetime.datetime.now().date().strftime("%Y%m%d")# 配置類
class Config(object):def __init__(self):"""配置類,包含模型和訓練所需的各種參數。"""self.model_name = "bert"  # 模型名稱# todo 路徑# 根目錄self.root_path = 'E:/PythonLearning/full_mask_project2/'# 原始數據路徑self.train_datapath = self.root_path + '01-data/train.txt'self.test_datapath = self.root_path + '01-data/test.txt'self.dev_datapath = self.root_path + '01-data/dev.txt'# 類別文檔self.class_path = self.root_path + "01-data/class.txt"# 類別名列表self.class_list = [line.strip() for line in open(self.class_path, encoding="utf-8")]  # 類別名單# 模型訓練保存路徑self.model_save_path = self.root_path + "dm_03_bert/save_models/bert_classifer_model.pt"  # 模型訓練結果保存路徑# 模型訓練+預測的時候  訓練設備,如果GPU可用,則為cuda,否則為cpuself.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.bert_path = r"E:\PythonLearning\full_mark_Project\dm04_Bert\bert-base-chinese"  # 預訓練BERT模型的路徑self.bert_model = BertModel.from_pretrained(self.bert_path)  # 加載預訓練BERT模型self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)  # BERT模型的分詞器self.bert_config = BertConfig.from_pretrained(self.bert_path)  # BERT模型的配置# todo 參數self.num_classes = len(self.class_list)  # 類別數self.num_epochs = 1  # epoch數self.batch_size = 64  # mini-batch大小self.pad_size = 32  # 每句話處理成的長度(短填長切)self.learning_rate = 5e-5  # 學習率# TODO 量化模型存放地址# 注意: 量化的時候模型需要的設備首選是cpuself.bert_model_quantization_model_path = self.root_path + "dm_03_bert/save_models/bert_classifer_quantization_model.pt"  # 模型訓練結果保存路徑if __name__ == '__main__':# 測試conf = Config()print(conf.device)print(conf.class_list)print(conf.tokenizer)input_size = conf.tokenizer.convert_tokens_to_ids(["你", "好", "中", "人"])print(input_size)print(conf.bert_model)print(conf.bert_config)

3、BERT 文本分類數據預處理流程

數據加載

def load_raw_data(file_path):"""從指定文件中加載原始數據。處理文本文件,返回(文本, 標簽類別索引)元組列表參數:file_path: 原始文本文件路徑返回:list: 包含(文本, 標簽類別索引)的元組列表,類別為int類型[('體驗2D巔峰 倚天屠龍記十大創新概覽', 8), ('60年鐵樹開花形狀似玉米芯(組圖)', 5)]"""result = []# 打印指定文件with open(file_path, 'r', encoding='utf-8') as f:# 使用tqdm包裝文件讀取迭代器,以便顯示加載數據的進度條for line in tqdm(f, desc=f"加載原始數據{file_path}"):# 移除行兩端的空白字符line = line.strip()# 跳過空行if not line:continue# 將行分割成文本和標簽兩部分text, label = line.split("\t")# 將標簽轉為int作為類別label = int(label)# 將文本和轉換為整數的標簽作為元組添加到數據列表中result.append((text, label))# 返回處理后的列表return result

數據集構建

# todo 2.自定義數據集
class TextDataset(Dataset):# 初始化數據def __init__(self, data_list):self.data_list = data_list# 返回數據集長度def __len__(self):return len(self.data_list)# 根據樣本索引,返回對應的特征和標簽def __getitem__(self, idx):text, label = self.data_list[idx]return text, label

批處理(padding/collate)

每當 DataLoader 從 Dataset 中取出一批batch 的原始數據后,
就會調用 collate_fn 來對這個 batch 進行統一處理(如填充、轉換為張量等)。
  • batch?是一個包含多個 (文本, 標簽) 元組的列表

  • 使用?zip(*batch)?將元組列表"轉置"為兩個元組:一個包含所有文本,一個包含所有標簽

  • 例如:[("text1", 1), ("text2", 2)]?→?("text1", "text2")?和?(1, 2)

  • add_special_tokens=True: 自動添加 [CLS] 和 [SEP] 等特殊token

  • padding='max_length': 將所有文本填充到固定長度?conf.pad_size

  • max_length=conf.pad_size: 設置最大長度

  • truncation=True: 如果文本超過最大長度則截斷

  • return_attention_mask=True: 返回注意力掩碼

  • input_ids: 文本轉換為的數字token ID序列

  • attention_mask: 標記哪些位置是實際文本(1),哪些是填充部分(0)

def collate_fn(batch):"""對batch數據進行padding處理參數: batch: 包含(文本, 標簽)元組的batch數據返回: tuple: 包含處理后的input_ids, attention_mask和labels的元組"""# todo 使用zip()將一批batch數據中的(text, label)元組拆分成兩個獨立的元組# texts = [item[0] for item in batch]# labels = [item[1] for item in batch]texts, labels = zip(*batch)# 對文本進行paddingtext_tokens = conf.tokenizer.batch_encode_plus(texts,add_special_tokens=True,  # 默認True,自動添加 [CLS] 和 [SEP]# padding=True,自動填充到數據中的最大長度       padding='max_length':填充到指定的固定長度padding='max_length',max_length=conf.pad_size,  # 設定目標長度truncation=True,  # 開啟截斷,防止超出模型限制return_attention_mask=True  # 請求返回注意力掩碼,以區分輸入中的有效信息和填充信息)# 從文本令牌中提取輸入IDinput_ids = text_tokens['input_ids']# 從文本令牌中提取注意力掩碼attention_mask = text_tokens['attention_mask']# 將輸入的token ID列表轉換為張量input_ids = torch.tensor(input_ids)# 將注意力掩碼列表轉換為張量attention_mask = torch.tensor(attention_mask)# 將標簽列表轉換為張量labels = torch.tensor(labels)# 返回轉換后的張量元組return input_ids, attention_mask, labels

DataLoader 封裝

def build_dataloader():# 加載原始數據train_data_list = load_raw_data(conf.train_datapath)dev_data_list = load_raw_data(conf.dev_datapath)test_data_list = load_raw_data(conf.test_datapath)# 構建訓練集train_dataset = TextDataset(train_data_list)dev_dataset = TextDataset(dev_data_list)test_dataset = TextDataset(test_data_list)# 構建DataLoadertrain_dataloader = DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)dev_dataloader = DataLoader(dev_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)test_dataloader = DataLoader(test_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)return train_dataloader, dev_dataloader, test_dataloader

完整代碼

# 加載數據工具類
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, TextDataset
from transformers.utils import PaddingStrategy
from config import Config
import time# 加載配置
conf = Config()# todo 加載并處理原始數據
def load_raw_data(file_path):"""從指定文件中加載原始數據。處理文本文件,返回(文本, 標簽, 類別)元組列表參數:file_path: 文本文件路徑返回:list: 包含(文本, 標簽, 類別)的元組列表,類別為int類型"""result = []# 打印指定文件with open(file_path, 'r', encoding='utf-8') as f:# 使用tqdm包裝文件讀取迭代器,以便顯示加載數據的進度條for line in tqdm(f, desc="加載數據..."):# 移除行兩端的空白字符line = line.strip()# 跳過空行if not line:continue# 將行分割成文本和標簽兩部分text, label = line.split("\t")# 將標簽轉為int作為類別label = int(label)# 將文本和轉換為整數的標簽作為元組添加到數據列表中result.append((text, label))# 返回處理后的列表return result# todo 自定義數據集
class TextDataset(Dataset):# 初始化數據def __init__(self, data_list):self.data_list = data_list# 返回數據集長度def __len__(self):return len(self.data_list)# 根據樣本索引,返回對應的特征和標簽def __getitem__(self, idx):text, label = self.data_list[idx]return text, label# todo 批量處理數據
# 每當 DataLoader 從 Dataset 中取出一個 batch 的原始數據后,
# 就會調用 collate_fn 來對這個 batch 進行統一處理(如填充、轉換為張量等)。
def collate_fn(batch):"""對batch數據進行padding處理參數: batch: 包含(文本, 標簽)元組的batch數據返回: tuple: 包含處理后的input_ids, attention_mask和labels的元組"""# todo 使用zip()將一批batch數據中的(text, label)元組拆分成兩個獨立的元組# texts = [item[0] for item in batch]# labels = [item[1] for item in batch]texts, labels = zip(*batch)# 對文本進行paddingtext_tokens = conf.tokenizer.batch_encode_plus(texts,add_special_tokens=True,  # 默認True,自動添加 [CLS] 和 [SEP]padding='max_length',  # 固定長度max_length=conf.pad_size,  # 設定目標長度truncation=True,  # 開啟截斷,防止超出模型限制return_attention_mask=True  # 請求返回注意力掩碼,以區分輸入中的有效信息和填充信息)# 從文本令牌中提取輸入IDinput_ids = text_tokens['input_ids']# 從文本令牌中提取注意力掩碼attention_mask = text_tokens['attention_mask']# 將輸入的token ID列表轉換為張量input_ids = torch.tensor(input_ids)# 將注意力掩碼列表轉換為張量attention_mask = torch.tensor(attention_mask)# 將標簽列表轉換為張量labels = torch.tensor(labels)# 返回轉換后的張量元組return input_ids, attention_mask, labels# todo 構建dataloader
def build_dataloader():# 加載原始數據train_data_list = load_raw_data(conf.train_datapath)dev_data_list = load_raw_data(conf.dev_datapath)test_data_list = load_raw_data(conf.test_datapath)# 構建訓練集train_dataset = TextDataset(train_data_list)dev_dataset = TextDataset(dev_data_list)test_dataset = TextDataset(test_data_list)# 構建DataLoadertrain_dataloader = DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)dev_dataloader = DataLoader(dev_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)test_dataloader = DataLoader(test_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)return train_dataloader, dev_dataloader, test_dataloaderif __name__ == '__main__':# 測試load_raw_data方法data_list = load_raw_data(conf.dev_datapath)print(data_list[:10])# 測試TextDataset類dataset = TextDataset(data_list)print(dataset[0])print(dataset[1])# 測試build_dataloader方法train_dataloader, dev_dataloader, test_dataloader = build_dataloader()print(len(train_dataloader))print(len(dev_dataloader))print(len(test_dataloader))# 測試collate_fn方法"""for i, batch in enumerate(train_dataloader)流程如下:1.DataLoader 從你的 Dataset 中取出一組索引;2.使用這些索引調用 Dataset.__getitem__ 獲取原始樣本;3.將這一組樣本組成一個 batch(通常是 (text, label) 元組的列表);4.自動調用你傳入的 collate_fn 函數來處理這個 batch 數據;5.返回處理后的 batch(如 input_ids, attention_mask, labels)供模型使用。"""for i, batch in enumerate(train_dataloader):print(len(batch))print(i)input_ids, attention_mask, labels = batch# print("input_ids: ", input_ids.tolist())print("input_ids.shape: ", input_ids.shape)# print("attention_mask: ", attention_mask.tolist())print("attention_mask.shape: ", attention_mask.shape)# print("labels: ", labels.tolist())print("labels.shape: ", labels.shape)break

4、對輸入文本進行分類

  • 流程

    1. 加載預訓練 BERT 模型(BertModel)作為特征提取器。

    2. 添加全連接層(nn.Linear)進行類別預測。

    3. 提供測試代碼,演示從文本輸入到預測結果的全過程。

  • outputs?是一個包含多個輸出的元組,對于分類任務我們主要關注:

    • last_hidden_state: 序列中每個 token 的隱藏狀態 (shape:?[batch_size, sequence_length, hidden_size])

    • pooler_output: 經過池化的整個序列的表示 (shape:?[batch_size, hidden_size])

假設:

  • batch_size = 32

  • sequence_length = 128?(經過 padding/truncation 后的長度)

  • hidden_size = 768?(BERT-base 的隱藏層大小)

  • num_classes = 10

數據流過程:

  1. 輸入?input_ids:?[32, 128]

  2. 輸入?attention_mask:?[32, 128]

  3. BERT 輸出?pooler_output:?[32, 768]

  4. 全連接層輸出?logits:?[32, 10]

import torch
import torch.nn as nn
from transformers import BertModel
from config import Config# 加載配置
config = Config()# 定義bert模型
class BertClassifier(nn.Module):def __init__(self):# 初始化父類類的構造函數super().__init__()# 下面的BertModel是從transformers庫中加載的預訓練模型# config.bert_path是預訓練模型的路徑self.bert = BertModel.from_pretrained(config.bert_path)# 定義全連接層(fc),用于分類任務# 輸入尺寸是Bert模型隱藏層的大小,即768(對于Base模型)# 輸出尺寸是類別數量,由config.num_classes指定self.fc = nn.Linear(config.bert_config.hidden_size, config.num_classes)def forward(self, input_ids, attention_mask):# 使用BERT模型處理輸入的token ID和注意力掩碼,獲取BERT模型的輸出# outputs是: _,pooledoutputs = self.bert(input_ids=input_ids,  # 輸入的token IDattention_mask=attention_mask  # 注意力掩碼用于區分有效token和填充token)# print(outputs) # 觀察結果# 通過全連接層對BERT模型的輸出進行分類logits = self.fc(outputs.pooler_output)# 返回分類的logits(未歸一化的預測分數)return logits# 測試以上模型
if __name__ == '__main__':# 測試model = BertClassifier()# 加載from transformers import BertTokenizertokenizer = BertTokenizer.from_pretrained(config.bert_path)# 示例文本texts = ["我喜歡你", "今天天氣真好"]# 編碼文本encoded_inputs = tokenizer(texts,# padding=True,  #  所有的填充到文本最大長度padding="max_length", # 所有的填充到指定的max_length長度truncation=True, # 如果超出指定的max_length長度,則截斷max_length=10,return_tensors="pt"  # 返回 pytorch 張量,"pt" 時,分詞器會將輸入文本轉換為模型可接受的格式)# 獲取 input_ids 和 attention_maskinput_ids = encoded_inputs["input_ids"]attention_mask = encoded_inputs["attention_mask"]print('input_ids:', input_ids)print('attention_mask:', attention_mask)print('======================================')# 預測logits = model(input_ids, attention_mask)print(logits)  # 每一行對應一個樣本,每個數字表示該樣本屬于某一類別的“得分”(logit),沒有經過 softmax 歸一化。print('-------------------------------')# 獲取預測概率probs = torch.softmax(logits, dim=-1)print(probs)  # 歸一化后該樣本屬于某類的概率(范圍在 0~1 之間),概率最高的就是預測結果print('-------------------------------')preds = torch.argmax(probs, dim=-1)print(preds)  # 得到每個樣本的預測類別。表示兩個輸入文本被模型預測為類別 6(從 0 開始計數)。

?5、計算模型的分類性能指標

功能:傳入模型、數據、設備,返回分類報告

完整代碼:

import torch
from sklearn.metrics import classification_report, f1_score, accuracy_score, precision_score, recall_score
from tqdm import tqdmdef model2dev(model, data_loader, device):"""在驗證或測試集上評估 BERT 分類模型的性能。參數:model (nn.Module): BERT 分類模型。data_loader (DataLoader): 數據加載器(驗證或測試集)。device (str): 設備("cuda" 或 "cpu")。返回:tuple: (分類報告, F1 分數, 準確度, 精確度,召回率)- report: 分類報告(包含每個類別的精確度、召回率、F1 分數等)。- f1score: 微平均 F1 分數。- accuracy: 準確度。- precision: 微平均精確度- recall: 微平均召回率"""# todo 1. 設置模型為評估模式(禁用 dropout,并改變batch_norm行為)model.eval()# 2. 初始化列表,存儲預測結果和真實標簽all_preds, all_labels = [], []# 3. todo torch.no_grad()禁用梯度計算以提高效率并減少內存占用with torch.no_grad():# 4. 遍歷數據加載器,逐批次進行預測for i, batch in enumerate(tqdm(data_loader, desc="驗證集評估中...")):# 4.1 提取批次數據并移動到設備input_ids, attention_mask, labels = batchinput_ids = input_ids.to(device)attention_mask = attention_mask.to(device)labels = labels.to(device)# 4.2 前向傳播:模型預測outputs = model(input_ids, attention_mask=attention_mask)# 4.3 獲取預測結果(最大 logits分數 對應的類別)y_pred_list = torch.argmax(outputs, dim=1)# 4.4 存儲預測和真實標簽all_preds.extend(y_pred_list.cpu().tolist())all_labels.extend(labels.cpu().tolist())# 5. 計算分類報告、F1 分數、準確率,精確率,召回率report = classification_report(all_labels, all_preds)f1score = f1_score(all_labels, all_preds, average='macro')accuracy = accuracy_score(all_labels, all_preds)precision = precision_score(all_labels, all_preds, average='macro')recall = recall_score(all_labels, all_preds, average='macro')# 6. 返回評估結果return report, f1score, accuracy, precision, recall

6、模型訓練

口訣:15241

加載配置文件和參數(1)

# todo 加載配置對象,包含模型參數、路徑等
conf = Config()
# todo 導入數據處理工具類
from a1_dataloader_utils import build_dataloader
# todo 導入bert模型
from a2_bert_classifer_model import BertClassifier

準備數據(4)

    # todo 1、準備數據train_dataloader, dev_dataloader, test_dataloader = build_dataloader()# todo 2、準備模型# 2.1初始化bert分類模型model = BertClassifier()# 2.2將模型移動到指定的設備model.to(conf.device)# todo 3.準備損失函數loss_fn = nn.CrossEntropyLoss()# todo 4.準備優化器optimizer = AdamW(model.parameters(), lr=conf.learning_rate)

模型訓練

外層遍歷輪次,內層遍歷批次

前向傳播:模型預測、計算損失

反向傳播:梯度清零、計算梯度、參數更新

# todo 5.開始訓練模型# 初始化F1分數,用于保存最好的模型best_f1 = 0.0# todo 5.1 外層循環遍歷每個訓練輪次#  (每次需要設置訓練模式,累計損失,預存訓練集測試和真實標簽)for epoch in range(conf.num_epochs):# 設置模型為訓練模式model.train()# 初始化累計損失,初始化訓練集預測和真實標簽total_loss = 0.0train_preds, train_labels = [], []# todo 5.2 內層循環遍歷訓練DataLoader每個批次for i, batch in enumerate(tqdm(train_dataloader, desc="訓練集訓練中...")):# 提取批次數據并移動到設備input_ids, attention_mask, labels = batchinput_ids = input_ids.to(conf.device)attention_mask = attention_mask.to(conf.device)labels = labels.to(conf.device)# todo 前向傳播:模型預測logits = model(input_ids, attention_mask)# todo 計算損失loss = loss_fn(logits, labels)# 累計損失total_loss += loss.item()# todo 獲取預測結果(最大logits對應的類別)y_pred_list = torch.argmax(logits, dim=1)# todo 存儲預測和真實標簽,用于計算訓練集指標train_preds.extend(y_pred_list.cpu().tolist())train_labels.extend(labels.cpu().tolist())# todo 梯度清零optimizer.zero_grad()# todo 反向傳播:計算梯度loss.backward()# todo 參數更新:根據梯度更新模型參數optimizer.step()

驗證評估

            # todo 每10個批次或一個輪次結束,計算訓練集指標if (i + 1) % 10 == 0 or i == len(train_dataloader) - 1:# 計算準確率和f1值acc = accuracy_score(train_labels, train_preds)f1 = f1_score(train_labels, train_preds, average='macro')# 獲取batch_count,并計算平均損失batch_count = i % 10 + 1avg_loss = total_loss / batch_count# todo 打印訓練信息print(f"\n輪次: {epoch + 1}, 批次: {i + 1}, 損失: {avg_loss:.4f}, acc準確率:{acc:.4f}, f1分數:{f1:.4f}")# todo 清空累計損失和預測和真實標簽total_loss = 0.0train_preds, train_labels = [], []# todo 每100個批次或一個輪次結束,計算驗證集指標,打印,保存模型if (i + 1) % 100 == 0 or i == len(train_dataloader) - 1:# 計算在測試集的評估報告,準確率,精確率,召回率,f1值report, f1score, accuracy, precision, recall = model2dev(model, dev_dataloader, conf.device)print("驗證集評估報告:\n", report)print(f"驗證集的f1: {f1score:.4f}, accuracy:{accuracy:.4f}, precision:{precision:.4f}, recall:{recall:.4f}")# todo 將模型再設置為訓練模式model.train()# todo 如果驗證F1分數優于歷史最佳,保存模型if f1score > best_f1:# 更新歷史最佳F1分數best_f1 = f1score# 保存模型torch.save(model.state_dict(), conf.model_save_path)print("保存模型成功, 當前f1分數:", best_f1)

完整代碼

import torch
import torch.nn as nn
from torch.optim import AdamW
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
from model2dev_utils import model2dev
from config import Config
from a1_dataloader_utils import build_dataloader
# 忽略的警告信息
import warningswarnings.filterwarnings("ignore")
# todo 導入bert模型
from a2_bert_classifer_model import BertClassifier# 加載配置對象,包含模型參數、路徑等
conf = Config()def model2train():"""訓練 BERT 分類模型并在驗證集上評估,保存最佳模型。參數:無顯式參數,所有配置通過全局 conf 對象獲取。返回:無返回值,訓練過程中保存最佳模型到指定路徑。"""# todo 1、準備數據train_dataloader, dev_dataloader, test_dataloader = build_dataloader()# todo 2、準備模型# 2.1初始化bert分類模型model = BertClassifier()# 2.2將模型移動到指定的設備model.to(conf.device)# todo 3.準備損失函數loss_fn = nn.CrossEntropyLoss()# todo 4.準備優化器optimizer = AdamW(model.parameters(), lr=conf.learning_rate)# todo 5.開始訓練模型# 初始化F1分數,用于保存最好的模型best_f1 = 0.0# todo 5.1 外層循環遍歷每個訓練輪次#  (每次需要設置訓練模式,累計損失,預存訓練集測試和真實標簽)for epoch in range(conf.num_epochs):# 設置模型為訓練模式model.train()# 初始化累計損失,初始化訓練集預測和真實標簽total_loss = 0.0train_preds, train_labels = [], []# todo 5.2 內層循環遍歷訓練DataLoader每個批次for i, batch in enumerate(tqdm(train_dataloader, desc="訓練集訓練中...")):# 提取批次數據并移動到設備input_ids, attention_mask, labels = batchinput_ids = input_ids.to(conf.device)attention_mask = attention_mask.to(conf.device)labels = labels.to(conf.device)# todo 前向傳播:模型預測outputs = model(input_ids, attention_mask)# todo 計算損失loss = loss_fn(outputs, labels)# todo 累計損失total_loss += loss.item()# todo 獲取預測結果(最大logits對應的類別)y_pred_list = torch.argmax(outputs, dim=1)# todo 存儲預測和真實標簽,用于計算訓練集指標train_preds.extend(y_pred_list.cpu().tolist())train_labels.extend(labels.cpu().tolist())# todo 梯度清零optimizer.zero_grad()# todo 反向傳播:計算梯度loss.backward()# todo 參數更新:根據梯度更新模型參數optimizer.step()# todo 每10個批次或一個輪次結束,計算訓練集指標if (i + 1) % 10 == 0 or i == len(train_dataloader) - 1:# 計算準確率和f1值acc = accuracy_score(train_labels, train_preds)f1 = f1_score(train_labels, train_preds, average='macro')# 獲取batch_count,并計算平均損失batch_count = i % 10 + 1avg_loss = total_loss / batch_count# todo 打印訓練信息print(f"\n輪次: {epoch + 1}, 批次: {i + 1}, 損失: {avg_loss:.4f}, acc準確率:{acc:.4f}, f1分數:{f1:.4f}")# todo 清空累計損失和預測和真實標簽total_loss = 0.0train_preds, train_labels = [], []# todo 每100個批次或一個輪次結束,計算驗證集指標,打印,保存模型if (i + 1) % 100 == 0 or i == len(train_dataloader) - 1:# 計算在測試集的評估報告,準確率,精確率,召回率,f1值report, f1score, accuracy, precision, recall = model2dev(model, dev_dataloader, conf.device)print("驗證集評估報告:\n", report)print(f"驗證集f1: {f1score:.4f}, accuracy:{accuracy:.4f}, precision:{precision:.4f}, recall:{recall:.4f}")# 將模型設置為訓練模式model.train()# todo 如果驗證F1分數優于歷史最佳,保存模型if f1score > best_f1:# 更新歷史最佳F1分數best_f1 = f1score# 保存模型torch.save(model.state_dict(), conf.model_save_path)print("保存模型成功, 當前f1分數:", best_f1)if __name__ == '__main__':model2train()

?7、基于BERT的文本分類預測接口

  1. 加載預訓練好的BERT分類模型(BertClassifier)。

  2. 對輸入文本進行分詞、編碼、轉換為張量。

  3. 使用模型預測文本類別。

  4. 返回類別名稱(如?"education"

import torch
from a2_bert_classifer_model import BertClassifier
from config import Config# 加載配置
conf = Config()# todo 準備模型
model = BertClassifier()
# todo 加載模型參數
model.load_state_dict(torch.load(conf.model_save_path))
# todo 添加模型到指定設備
model.to(conf.device)
# todo 設置模型為評估模式
model.eval()# TODO 定義predict_fun函數預測函數
def predict_fun(data_dict):"""根據用戶錄入數據,返回分類信息:param 參數 data_dict: {"text":"狀元心經:考前一周重點是回顧和整理"}:return: 返回 data_dict: {"text":"狀元心經:考前一周重點是回顧和整理", "pred_class":"education"}"""# todo 獲取文本text = data_dict['text']# todo 將文本轉為idtext_tokens = conf.tokenizer.batch_encode_plus([text],padding="max_length",max_length=conf.pad_size,pad_to_max_length=True)# todo 獲取input_ids和attention_maskinput_ids = text_tokens['input_ids']attention_mask = text_tokens['attention_mask']# todo 將input_ids和attention_mask轉為tensor, 并指定到設備input_ids = torch.tensor(input_ids).to(conf.device)attention_mask = torch.tensor(attention_mask).to(conf.device)# todo 設置不進行梯度計算(在該上下文中禁用梯度計算,提升推理速度并減少內存占用)with torch.no_grad():# 前向傳播(模型預測)output = model(input_ids, attention_mask)# 獲取預測類別索引張量output = torch.argmax(output, dim=1)# 獲取預測類別索引標量pred_idx = output.item()# 獲取類別名稱pred_class = conf.class_list[pred_idx]print(pred_class)# 將預測結果添加到data_dict中data_dict['pred_class'] = pred_class# 返回data_dictreturn data_dictif __name__ == '__main__':data_dict = {'text': '狀元心經:考前一周重點是回顧和整理'}print(predict_fun(data_dict))

問題總結

持續更新中......

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/90765.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/90765.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/90765.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

sendfile系統調用及示例

好的,我們繼續學習 Linux 系統編程中的重要函數。這次我們介紹 sendfile 函數,它是一個高效的系統調用,用于在兩個文件描述符之間直接傳輸數據,通常用于將文件內容發送到網絡套接字,而無需將數據從內核空間復制到用戶空…

數據結構習題--刪除排序數組中的重復項

數據結構習題–刪除排序數組中的重復項 給你一個 非嚴格遞增排列 的數組 nums ,請你 原地 刪除重復出現的元素,使每個元素 只出現一次 ,返回刪除后數組的新長度。元素的 相對順序 應該保持 一致 。然后返回 nums 中唯一元素的個數。 方法&…

Docker的容器設置隨Docker的啟動而啟動

原因也比較簡單,在docker run 的時候沒有設置–restartalways參數。 容器啟動時,需要增加參數 –restartalways no - 容器退出時,不重啟容器; on-failure - 只有在非0狀態退出時才從新啟動容器; always - 無論退出狀態…

JWT安全機制與最佳實踐詳解

JWT(JSON Web Token) 是一種開放標準(RFC 7519),用于在各方之間安全地傳輸信息作為緊湊且自包含的 JSON 對象。它被廣泛用于身份驗證(Authentication)和授權(Authorization&#xff…

如何解決pip安裝報錯ModuleNotFoundError: No module named ‘ipython’問題

【Python系列Bug修復PyCharm控制臺pip install報錯】如何解決pip安裝報錯ModuleNotFoundError: No module named ‘ipython’問題 摘要 在開發過程中,我們常常會遇到pip install報錯的問題,其中一個常見的報錯是 ModuleNotFoundError: No module named…

從三維Coulomb勢到二維對數勢的下降法推導

題目 問題 7. 應用 9.1.4 小節描述的下降法,但針對二維的拉普拉斯方程,并從三維的 Coulomb 勢出發 KaTeX parse error: Invalid delimiter: {"type":"ordgroup","mode":"math","loc":{"lexer&qu…

直播一體機技術方案解析:基于RK3588S的硬件架構特性?

硬件配置??主控平臺??? 搭載瑞芯微RK3588S旗艦處理器(四核A762.4GHz 四核A55)? 集成ARM Mali-G610 MP4 GPU 6TOPS算力NPU? 雙通道LPDDR5內存 UFS3.1存儲組合??專用加速單元??→ 板載視頻采集模塊:支持4K60fps HDMI環出采集→ 集…

【氮化鎵】GaN取代GaAs作為空間激光無線能量傳輸光伏轉換器材料

2025年7月1日,西班牙圣地亞哥-德孔波斯特拉大學的Javier F. Lozano等人在《Optics and Laser Technology》期刊發表了題為《Gallium nitride: a strong candidate to replace GaAs as base material for optical photovoltaic converters in space exploration》的文章,基于T…

直播美顏SDK動態貼紙模塊開發指南:從人臉關鍵點識別到3D貼合

很多美顏技術開發者好奇,如何在直播美顏SDK中實現一個高質量的動態貼紙模塊?這不是簡單地“貼圖貼臉”,而是一個融合人臉關鍵點識別、實時渲染、貼紙驅動邏輯、3D骨骼動畫與跨平臺性能優化的系統工程。今天,就讓我們從底層技術出發…

學習游戲制作記錄(劍投擲技能)7.26

1.實現瞄準狀態和接劍狀態準備好瞄準動畫,投擲動畫和接劍動畫,并設置參數AimSword和CatchSword投擲動畫在瞄準動畫后,瞄準結束后才能投擲創建PlayerAimSwordState腳本和PlayerCatchSwordState腳本并在Player中初始化:PlayerAimSwo…

【c++】問答系統代碼改進解析:新增日志系統提升可維護性——關于我用AI編寫了一個聊天機器人……(14)

在軟件開發中,代碼的迭代優化往往從提升可維護性、可追蹤性入手。本文將詳細解析新增的日志系統改進,以及這些改進如何提升系統的實用性和可調試性。一、代碼整體背景代碼實現了一個基于 TF-IDF 算法的問答系統,核心功能包括:加載…

visual studio2022編譯unreal engine5.4.4源碼

UE5系列文章目錄 文章目錄 UE5系列文章目錄 前言 一、ue5官網 二.編譯源碼中遇到的問題 前言 一、ue5官網 UE5官網 UE5源碼下載地址 這樣雖然下載比較快,但是不能進行代碼git管理,以后如何虛幻官方有大的版本變動需要重新下載源碼,所以我們還是最好需要visual studio2022…

vulhub Earth靶場攻略

靶場下載 下載鏈接:https://download.vulnhub.com/theplanets/Earth.ova 靶場使用 將壓縮包解壓到一個文件夾中,右鍵,用虛擬機打開,就創建成功了,然后啟動虛擬機: 這時候靶場已經啟動了,咱們現…

Python訓練Day24

浙大疏錦行 元組可迭代對象os模塊

Spring核心:Bean生命周期、外部化配置與組件掃描深度解析

Bean生命周期 說明 程序中的每個對象都有生命周期,對象的創建、初始化、應用、銷毀的整個過程稱之為對象的生命周期; 在對象創建以后需要初始化,應用完成以后需要銷毀時執行的一些方法,可以稱之為是生命周期方法; 在sp…

日語學習-日語知識點小記-進階-JLPT-真題訓練-N1階段(1):2017年12月-JLPT-N1

日語學習-日語知識點小記-進階-JLPT-真題訓練-N1階段(1):2017年12月-JLPT-N1 1、前言(1)情況說明(2)工程師的信仰(3)真題訓練2、真題-2017年12月-JLPT-N1(1&a…

(一)使用 LangChain 從零開始構建 RAG 系統|RAG From Scratch

RAG 的主要動機 大模型訓練的時候雖然使用了龐大的世界數據,但是并沒有涵蓋用戶關心的所有數據, 其預訓練令牌(token)數量雖大但相對這些數據仍有限。另外大模型輸入的上下文窗口越來越大,從幾千個token到幾萬個token,…

OpenCV學習探秘之一 :了解opencv技術及架構解析、數據結構與內存管理?等基礎

?一、OpenCV概述與技術演進? 1.1技術歷史? OpenCV(Open Source Computer Vision Library)是由Intel于1999年發起創建的開源計算機視覺庫,后來交由OpenCV開源社區維護,旨在為計算機視覺應用提供通用基礎設施。經歷20余年發展&…

什么是JUC

摘要 Java并發工具包JUC是JDK5.0引入的重要并發編程工具,提供了更高級、靈活的并發控制機制。JUC包含鎖與同步器(如ReentrantLock、Semaphore等)、線程安全隊列(BlockingQueue)、原子變量(AtomicInteger等…

零基礎學后端-PHP語言(第二期-PHP基礎語法)(通過php內置服務器運行php文件)

經過上期的配置,我們已經有了php的開發環境,編輯器我們繼續使用VScode,如果是新來的朋友可以看這期文章來配置VScode 零基礎學前端-傳統前端開發(第一期-開發軟件介紹與本系列目標)(VScode安裝教程&#x…