項目說明
項目名稱
基于DistilBERT的標題多分類任務
項目概述
本項目旨在使用DistilBERT模型對給定的標題文本進行多分類任務。項目包括從數據處理、模型訓練、模型評估到最終的API部署。該項目采用模塊化設計,以便于理解和維護。
項目結構
.
├── bert_data
│ ├── train.txt
│ ├── dev.txt
│ └── test.txt
├── saved_model
├── results
├── logs
├── data_processing.py
├── dataset.py
├── training.py
├── app.py
└── main.py
文件說明
-
bert_data/:存放訓練集、驗證集和測試集的數據文件。
- train.txt
- dev.txt
- test.txt
-
saved_model/:存放訓練好的模型和tokenizer。
-
results/:存放訓練結果。
-
logs/:存放訓練日志。
-
data_processing.py:數據處理模塊,負責讀取和預處理數據。
-
dataset.py:數據集類模塊,定義了用于訓練和評估的數據集類。
-
training.py:模型訓練模塊,定義了訓練和評估模型的過程。
-
app.py:模型部署模塊,使用FastAPI創建API服務。
-
main.py:主腳本,運行整個流程,包括數據處理、模型訓練和部署。
數據集數據規范
為了確保數據處理和模型訓練的順利進行,請按照以下規范準備數據集文件。每個文件包含的標題和標簽分別使用制表符(\t
)分隔。以下是一個示例數據集的格式。
數據文件格式
數據文件應為純文本文件,擴展名為.txt
,文件內容的每一行應包含一個文本標題和一個對應的分類標簽,用制表符分隔。數據文件不應包含表頭。
數據示例
探索神秘的海底世界 7
如何在家中制作美味披薩 2
全球氣候變化的原因和影響 1
最新的智能手機評測 8
健康飲食:如何搭配均衡的膳食 5
最受歡迎的電影和電視劇推薦 3
了解宇宙的奧秘:天文學入門 0
如何種植和照顧多肉植物 9
時尚潮流:今年夏天的必備單品 6
如何有效管理個人財務 4
注意事項
- 標簽規范:確保每個標題文本的標簽是一個整數,表示類別。
- 文本編碼:確保數據文件使用UTF-8編碼,避免中文字符亂碼。
- 數據一致性:確保訓練、驗證和測試數據格式一致,便于數據加載和處理。
通過以上規范和示例數據文件創建方法,可以確保數據文件符合項目需求,并順利進行數據處理和模型訓練。
模塊說明
1. 數據處理模塊 (data_processing.py)
功能:讀取數據文件并進行預處理。
load_data(file_path)
: 讀取指定路徑的數據文件,并返回一個包含文本和標簽的數據框。tokenize_data(data, tokenizer, max_length=128)
: 使用BERT的tokenizer對數據進行tokenize處理。main()
: 加載數據、tokenize數據并返回處理后的數據。
2. 數據集類模塊 (dataset.py)
功能:定義數據集類,便于模型訓練。
TextDataset
: 將tokenized數據和標簽封裝成PyTorch的數據集格式,便于Trainer進行訓練和評估。
3. 模型訓練模塊 (training.py)
功能:定義訓練和評估模型的過程。
train_model()
: 加載數據和tokenizer,創建數據集,加載模型,設置訓練參數,定義Trainer,訓練和評估模型,保存訓練好的模型和tokenizer。
4. 模型部署模塊 (app.py)
功能:使用FastAPI進行模型部署。
predict(item: Item)
: 接收POST請求的文本輸入,使用訓練好的模型進行預測并返回分類結果。- FastAPI應用啟動配置。
5. 主腳本 (main.py)
功能:運行整個流程,包括數據處理、模型訓練和部署。
main()
: 運行模型訓練流程,并輸出訓練完成的提示。
運行步驟
- 安裝依賴
pip install pandas torch transformers fastapi uvicorn scikit-learn
- 數據處理
確保bert_data
文件夾下包含train.txt
、dev.txt
和test.txt
文件,每個文件包含文本和標簽,使用制表符分隔。
- 訓練模型
運行main.py
腳本,進行數據處理和模型訓練:
python main.py
訓練完成后,模型和tokenizer將保存在saved_model
文件夾中。
- 部署模型
運行app.py
腳本,啟動API服務:
uvicorn app:app --reload
服務啟動后,可以通過POST請求訪問預測接口,進行文本分類預測。
示例請求
curl -X POST "http://localhost:8000/predict" -H "Content-Type: application/json" -d '{"text": "你的文本"}'
返回示例:
{"prediction": 3
}
注意事項
- 確保數據文件格式正確,每行包含一個文本和對應的標簽,使用制表符分隔。
- 調整訓練參數(如batch size和訓練輪數)以適應不同的GPU配置。
- 使用
nvidia-smi
監控顯存使用,避免顯存溢出。
項目代碼
1. 數據處理模塊
功能:讀取數據文件并進行預處理。
# data_processing.py
import pandas as pd
from transformers import DistilBertTokenizerdef load_data(file_path):data = pd.read_csv(file_path, delimiter='\t', header=None)data.columns = ['text', 'label']return datadef tokenize_data(data, tokenizer, max_length=128):encodings = tokenizer(list(data['text']), truncation=True, padding=True, max_length=max_length)return encodingsdef main():# 加載Tokenizertokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-chinese')# 加載數據train_data = load_data('./bert_data/train.txt')dev_data = load_data('./bert_data/dev.txt')test_data = load_data('./bert_data/test.txt')# Tokenize數據train_encodings = tokenize_data(train_data, tokenizer)dev_encodings = tokenize_data(dev_data, tokenizer)test_encodings = tokenize_data(test_data, tokenizer)return train_encodings, dev_encodings, test_encodings, train_data['label'], dev_data['label'], test_data['label']if __name__ == "__main__":main()
2. 數據集類模塊
功能:定義數據集類,便于模型訓練。
# dataset.py
import torchclass TextDataset(torch.utils.data.Dataset):def __init__(self, encodings, labels):self.encodings = encodingsself.labels = labelsdef __getitem__(self, idx):item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}item['labels'] = torch.tensor(self.labels[idx])return itemdef __len__(self):return len(self.labels)
3. 模型訓練模塊
功能:定義訓練和評估模型的過程。
# training.py
import torch
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from dataset import TextDataset
import data_processingdef train_model():# 加載數據和tokenizertrain_encodings, dev_encodings, test_encodings, train_labels, dev_labels, test_labels = data_processing.main()# 創建數據集train_dataset = TextDataset(train_encodings, train_labels)dev_dataset = TextDataset(dev_encodings, dev_labels)test_dataset = TextDataset(test_encodings, test_labels)# 加載DistilBERT模型model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-chinese', num_labels=10)model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))# 設置訓練參數training_args = TrainingArguments(output_dir='./results', # 輸出結果目錄num_train_epochs=3, # 訓練輪數per_device_train_batch_size=16, # 訓練時每個設備的批量大小per_device_eval_batch_size=64, # 驗證時每個設備的批量大小warmup_steps=500, # 訓練步數weight_decay=0.01, # 權重衰減logging_dir='./logs', # 日志目錄fp16=True, # 啟用混合精度訓練)# 定義Trainertrainer = Trainer(model=model, # 預訓練模型args=training_args, # 訓練參數train_dataset=train_dataset, # 訓練數據集eval_dataset=dev_dataset # 驗證數據集)# 訓練模型trainer.train()# 評估模型eval_results = trainer.evaluate()print(eval_results)# 保存模型model.save_pretrained('./saved_model')tokenizer.save_pretrained('./saved_model')if __name__ == "__main__":train_model()
4. 模型部署模塊
功能:使用FastAPI進行模型部署。
# app.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torchapp = FastAPI()# 加載模型和tokenizer
model = DistilBertForSequenceClassification.from_pretrained('./saved_model')
tokenizer = DistilBertTokenizer.from_pretrained('./saved_model')class Item(BaseModel):text: str@app.post("/predict")
def predict(item: Item):inputs = tokenizer(item.text, return_tensors="pt", max_length=128, padding='max_length', truncation=True)outputs = model(**inputs)prediction = torch.argmax(outputs.logits, dim=1)return {"prediction": prediction.item()}if __name__ == "__main__":import uvicornuvicorn.run(app, host="0.0.0.0", port=8000)
5. 主腳本
功能:運行整個流程,包括數據處理、模型訓練和部署。
# main.py
import trainingdef main():# 訓練模型training.train_model()print("模型訓練完成并保存。")if __name__ == "__main__":main()
詳細說明
-
數據處理模塊:
- 讀取訓練集、驗證集和測試集的數據文件。
- 使用BERT的Tokenizer對數據進行tokenize處理,生成模型可接受的輸入格式。
- 提供主要的數據處理函數,包括加載數據和tokenize數據。
-
數據集類模塊:
- 定義一個
TextDataset
類,用于將tokenized數據和標簽封裝成PyTorch的數據集格式,便于Trainer進行訓練和評估。
- 定義一個
-
模型訓練模塊:
- 使用數據處理模塊加載和tokenize數據。
- 創建訓練和驗證數據集。
- 加載DistilBERT模型,并設置訓練參數(包括啟用混合精度訓練)。
- 使用
Trainer
進行模型訓練和評估,并保存訓練好的模型。
-
模型部署模塊:
- 使用FastAPI創建一個簡單的API服務。
- 加載保存的模型和tokenizer。
- 定義一個預測接口,通過POST請求接收文本輸入并返回分類預測結果。
-
主腳本:
- 運行模型訓練流程,并輸出訓練完成的提示。