NLP - 基于bert預訓練模型的文本多分類示例

項目說明

項目名稱

基于DistilBERT的標題多分類任務

項目概述

本項目旨在使用DistilBERT模型對給定的標題文本進行多分類任務。項目包括從數據處理、模型訓練、模型評估到最終的API部署。該項目采用模塊化設計,以便于理解和維護。

項目結構

.
├── bert_data
│   ├── train.txt
│   ├── dev.txt
│   └── test.txt
├── saved_model
├── results
├── logs
├── data_processing.py
├── dataset.py
├── training.py
├── app.py
└── main.py

文件說明

  1. bert_data/:存放訓練集、驗證集和測試集的數據文件。

    • train.txt
    • dev.txt
    • test.txt
  2. saved_model/:存放訓練好的模型和tokenizer。

  3. results/:存放訓練結果。

  4. logs/:存放訓練日志。

  5. data_processing.py:數據處理模塊,負責讀取和預處理數據。

  6. dataset.py:數據集類模塊,定義了用于訓練和評估的數據集類。

  7. training.py:模型訓練模塊,定義了訓練和評估模型的過程。

  8. app.py:模型部署模塊,使用FastAPI創建API服務。

  9. main.py:主腳本,運行整個流程,包括數據處理、模型訓練和部署。

數據集數據規范

為了確保數據處理和模型訓練的順利進行,請按照以下規范準備數據集文件。每個文件包含的標題和標簽分別使用制表符(\t)分隔。以下是一個示例數據集的格式。

數據文件格式

數據文件應為純文本文件,擴展名為.txt,文件內容的每一行應包含一個文本標題和一個對應的分類標簽,用制表符分隔。數據文件不應包含表頭。

數據示例
探索神秘的海底世界    7
如何在家中制作美味披薩    2
全球氣候變化的原因和影響    1
最新的智能手機評測    8
健康飲食:如何搭配均衡的膳食    5
最受歡迎的電影和電視劇推薦    3
了解宇宙的奧秘:天文學入門    0
如何種植和照顧多肉植物    9
時尚潮流:今年夏天的必備單品    6
如何有效管理個人財務    4

注意事項

  • 標簽規范:確保每個標題文本的標簽是一個整數,表示類別。
  • 文本編碼:確保數據文件使用UTF-8編碼,避免中文字符亂碼。
  • 數據一致性:確保訓練、驗證和測試數據格式一致,便于數據加載和處理。

通過以上規范和示例數據文件創建方法,可以確保數據文件符合項目需求,并順利進行數據處理和模型訓練。

模塊說明

1. 數據處理模塊 (data_processing.py)

功能:讀取數據文件并進行預處理。

  • load_data(file_path): 讀取指定路徑的數據文件,并返回一個包含文本和標簽的數據框。
  • tokenize_data(data, tokenizer, max_length=128): 使用BERT的tokenizer對數據進行tokenize處理。
  • main(): 加載數據、tokenize數據并返回處理后的數據。
2. 數據集類模塊 (dataset.py)

功能:定義數據集類,便于模型訓練。

  • TextDataset: 將tokenized數據和標簽封裝成PyTorch的數據集格式,便于Trainer進行訓練和評估。
3. 模型訓練模塊 (training.py)

功能:定義訓練和評估模型的過程。

  • train_model(): 加載數據和tokenizer,創建數據集,加載模型,設置訓練參數,定義Trainer,訓練和評估模型,保存訓練好的模型和tokenizer。
4. 模型部署模塊 (app.py)

功能:使用FastAPI進行模型部署。

  • predict(item: Item): 接收POST請求的文本輸入,使用訓練好的模型進行預測并返回分類結果。
  • FastAPI應用啟動配置。
5. 主腳本 (main.py)

功能:運行整個流程,包括數據處理、模型訓練和部署。

  • main(): 運行模型訓練流程,并輸出訓練完成的提示。

運行步驟

  1. 安裝依賴
pip install pandas torch transformers fastapi uvicorn scikit-learn
  1. 數據處理

確保bert_data文件夾下包含train.txtdev.txttest.txt文件,每個文件包含文本和標簽,使用制表符分隔。

  1. 訓練模型

運行main.py腳本,進行數據處理和模型訓練:

python main.py

訓練完成后,模型和tokenizer將保存在saved_model文件夾中。

  1. 部署模型

運行app.py腳本,啟動API服務:

uvicorn app:app --reload

服務啟動后,可以通過POST請求訪問預測接口,進行文本分類預測。

示例請求

curl -X POST "http://localhost:8000/predict" -H "Content-Type: application/json" -d '{"text": "你的文本"}'

返回示例:

{"prediction": 3
}

注意事項

  • 確保數據文件格式正確,每行包含一個文本和對應的標簽,使用制表符分隔。
  • 調整訓練參數(如batch size和訓練輪數)以適應不同的GPU配置。
  • 使用nvidia-smi監控顯存使用,避免顯存溢出。

項目代碼

1. 數據處理模塊

功能:讀取數據文件并進行預處理。

# data_processing.py
import pandas as pd
from transformers import DistilBertTokenizerdef load_data(file_path):data = pd.read_csv(file_path, delimiter='\t', header=None)data.columns = ['text', 'label']return datadef tokenize_data(data, tokenizer, max_length=128):encodings = tokenizer(list(data['text']), truncation=True, padding=True, max_length=max_length)return encodingsdef main():# 加載Tokenizertokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-chinese')# 加載數據train_data = load_data('./bert_data/train.txt')dev_data = load_data('./bert_data/dev.txt')test_data = load_data('./bert_data/test.txt')# Tokenize數據train_encodings = tokenize_data(train_data, tokenizer)dev_encodings = tokenize_data(dev_data, tokenizer)test_encodings = tokenize_data(test_data, tokenizer)return train_encodings, dev_encodings, test_encodings, train_data['label'], dev_data['label'], test_data['label']if __name__ == "__main__":main()

2. 數據集類模塊

功能:定義數據集類,便于模型訓練。

# dataset.py
import torchclass TextDataset(torch.utils.data.Dataset):def __init__(self, encodings, labels):self.encodings = encodingsself.labels = labelsdef __getitem__(self, idx):item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}item['labels'] = torch.tensor(self.labels[idx])return itemdef __len__(self):return len(self.labels)

3. 模型訓練模塊

功能:定義訓練和評估模型的過程。

# training.py
import torch
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from dataset import TextDataset
import data_processingdef train_model():# 加載數據和tokenizertrain_encodings, dev_encodings, test_encodings, train_labels, dev_labels, test_labels = data_processing.main()# 創建數據集train_dataset = TextDataset(train_encodings, train_labels)dev_dataset = TextDataset(dev_encodings, dev_labels)test_dataset = TextDataset(test_encodings, test_labels)# 加載DistilBERT模型model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-chinese', num_labels=10)model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))# 設置訓練參數training_args = TrainingArguments(output_dir='./results',          # 輸出結果目錄num_train_epochs=3,              # 訓練輪數per_device_train_batch_size=16,  # 訓練時每個設備的批量大小per_device_eval_batch_size=64,   # 驗證時每個設備的批量大小warmup_steps=500,                # 訓練步數weight_decay=0.01,               # 權重衰減logging_dir='./logs',            # 日志目錄fp16=True,                       # 啟用混合精度訓練)# 定義Trainertrainer = Trainer(model=model,                         # 預訓練模型args=training_args,                  # 訓練參數train_dataset=train_dataset,         # 訓練數據集eval_dataset=dev_dataset             # 驗證數據集)# 訓練模型trainer.train()# 評估模型eval_results = trainer.evaluate()print(eval_results)# 保存模型model.save_pretrained('./saved_model')tokenizer.save_pretrained('./saved_model')if __name__ == "__main__":train_model()

4. 模型部署模塊

功能:使用FastAPI進行模型部署。

# app.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torchapp = FastAPI()# 加載模型和tokenizer
model = DistilBertForSequenceClassification.from_pretrained('./saved_model')
tokenizer = DistilBertTokenizer.from_pretrained('./saved_model')class Item(BaseModel):text: str@app.post("/predict")
def predict(item: Item):inputs = tokenizer(item.text, return_tensors="pt", max_length=128, padding='max_length', truncation=True)outputs = model(**inputs)prediction = torch.argmax(outputs.logits, dim=1)return {"prediction": prediction.item()}if __name__ == "__main__":import uvicornuvicorn.run(app, host="0.0.0.0", port=8000)

5. 主腳本

功能:運行整個流程,包括數據處理、模型訓練和部署。

# main.py
import trainingdef main():# 訓練模型training.train_model()print("模型訓練完成并保存。")if __name__ == "__main__":main()

詳細說明

  1. 數據處理模塊

    • 讀取訓練集、驗證集和測試集的數據文件。
    • 使用BERT的Tokenizer對數據進行tokenize處理,生成模型可接受的輸入格式。
    • 提供主要的數據處理函數,包括加載數據和tokenize數據。
  2. 數據集類模塊

    • 定義一個TextDataset類,用于將tokenized數據和標簽封裝成PyTorch的數據集格式,便于Trainer進行訓練和評估。
  3. 模型訓練模塊

    • 使用數據處理模塊加載和tokenize數據。
    • 創建訓練和驗證數據集。
    • 加載DistilBERT模型,并設置訓練參數(包括啟用混合精度訓練)。
    • 使用Trainer進行模型訓練和評估,并保存訓練好的模型。
  4. 模型部署模塊

    • 使用FastAPI創建一個簡單的API服務。
    • 加載保存的模型和tokenizer。
    • 定義一個預測接口,通過POST請求接收文本輸入并返回分類預測結果。
  5. 主腳本

    • 運行模型訓練流程,并輸出訓練完成的提示。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/40200.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/40200.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/40200.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

蘋果AI的國產大模型之爭,沒有懸念

文 | 智能相對論 作者 | 陳泊丞 蘋果終于公布了最新的AI進程。 一個月前,正如此前預期的那樣,人工智能是今年 WWDC 發布會的焦點。全程105分鐘的主題演講,就有40多分鐘用于介紹蘋果的AI成果。 蘋果似乎還有意玩了一把“諧音梗”&#xff…

用機器改變人類方向

1800 世紀初,美國迎來了工業革命,這是一個由技術進步推動的變革時代。新機器和制造技術的引入重塑了經濟格局,提高了生產效率,同時減少了某些領域對手工勞動的需求。因此,這種轉變導致了失業。 如今,我們看…

實現點擊按鈕導出頁面pdf

在Vue 3 Vite項目中,你可以使用html2canvas和jspdf庫來實現將頁面某部分導出為PDF文檔的功能。以下是一個簡單的實現方式: 1.安裝html2canvas和jspdf: pnpm install html2canvas jspdf 2.在Vue組件中使用這些庫來實現導出功能:…

統計信號處理基礎 習題解答11-11

題目 考慮矢量MAP估計量 證明這個估計量對于代價函數 使貝葉斯風險最小。其中:, ,且. 解答 貝葉斯風險函數: 基于概率密度的非負特性,上述對積分要求最小,那就需要內層積分達到最小。令內層積分為: 上述積…

蘋果Mac電腦能玩什么游戲 Mac怎么運行Windows游戲

相對于Windows平臺來說,Mac電腦可玩的游戲較少。雖然蘋果設備的性能足以支持各種大型游戲,但由于系統以及蘋果配套服務的限制,很多游戲無法在Mac系統中運行。不過,借助虛擬機軟件,Mac電腦可以突破系統限制玩更多的游戲…

react中jsx的語法規則

1.react核心庫react.development.js 2.react_dom庫,用于支持react操作dom(react-dom.development.js) 3.引入bable,解析jsx語法的庫,用于將jsx轉換為js(babel.min.js) 上述三個庫是寫基礎react的基本庫 下面我將用…

光照老化試驗箱在化工產品暴曬測試中的應用

概述 光照老化試驗箱是一種模擬自然光照條件下材料老化情況的實驗設備,廣泛應用于化工、建材、電子、汽車等行業中對材料的耐候性、耐光性能等進行測試。通過模擬日光中的紫外線和溫度等環境因素,加速材料老化過程,以此評估材料在長期使用中…

2024阿里云大模型自定義插件(如何調用自定義接口)

1,自定義插件入口 2,插件定義:描述插件的參數 2.1,注意事項: 2.1.1,只支持json格式的參數;只支持application/JSON;如下圖: 2.1.2,需要把接口描述進行修改&a…

03:Spring MVC

文章目錄 一:Spring MVC簡介1:說說自己對于Spring MVC的了解?1.1:流程說明: 一:Spring MVC簡介 Spring MVC就是一個MVC框架,Spring MVC annotation式的開發比Struts2方便,可以直接代…

LeetCode 算法:二叉搜索樹中第K小的元素 c++

原題鏈接🔗:二叉搜索樹中第K小的元素 難度:中等???? 題目 給定一個二叉搜索樹的根節點 root ,和一個整數 k ,請你設計一個算法查找其中第 k 小的元素(從1開始計數)。 示例 1:…

網絡爬蟲之什么是代碼混淆?初步理解代碼混淆

爬蟲逆向之什么是代碼混淆?初步理解代碼混淆 在網絡爬蟲和逆向工程的過程中,代碼混淆是一項常見的技術,旨在保護代碼不被輕易理解和逆向。對于爬蟲工程師來說,理解并破解代碼混淆是一個重要的技能。本文將詳細介紹代碼混淆的基本概…

GUI開發

Question One Java 實現動作監聽,網格布局添加四個按鈕,實現四個不同的文本顯示 import java.awt.*; import java.awt.event.*; import javax.swing.*;class myGUI extends JFrame implements ActionListener{private Button b1, b2, b3, b4;private Tex…

0627,0628,0629,排序,文件

01:請實現選擇排序,并分析它的時間復雜度,空間復雜度和穩定性 void selection_sort(int arr[], int n); 解答: 穩定性:穩定, 不穩定的,會發生長距離的交換 4 9 9 4 1 &#xf…

ubuntu,linux下屏蔽壞塊方法-240625-240702封存

在windows下的屏蔽壞道的方法 機械硬盤壞道的文件系統級別的屏蔽方法_硬盤如何屏蔽壞扇區-CSDN博客 https://blog.csdn.net/cyuyan112233/article/details/139408503?spm1001.2014.3001.5502 【免費】磁盤壞道屏蔽工具磁盤壞道屏蔽工具_機械硬盤屏蔽壞扇區資源-CSDN文庫 https…

第一周題目總結

1.車爾尼有一個數組 nums ,它只包含 正 整數,所有正整數的數位長度都 相同 。 兩個整數的 數位不同 指的是兩個整數 相同 位置上不同數字的數目。 請車爾尼返回 nums 中 所有 整數對里,數位不同之和。 示例 1: 輸入&#xff1a…

【嵌入式DIY實例-ESP8266篇】-LCD ST7735顯示網絡時間

LCD ST7735顯示網絡時間 文章目錄 LCD ST7735顯示網絡時間1、硬件準備2、代碼實現本文將介紹如何使用 ESP8266 NodeMCU Wi-Fi 板實現互聯網時鐘,其中時間和日期顯示在 ST7735 TFT 顯示屏上。 ST7735 TFT是一款分辨率為128160像素的彩色顯示屏,采用SPI協議與主控設備通信。 1…

Python中的變量和數據類型:Python中有哪些基本數據類型以及變量是如何聲明的

在Python中,變量是用來存儲數據的容器,而數據類型則定義了這些數據的種類。Python是一種動態類型語言,這意味著你不需要在聲明變量時指定其類型;Python解釋器會在運行時自動確定變量的類型。 Python中的基本數據類型 Python中有…

SQL語句(DML)

DML英文全稱是Data Manipulation Language(數據操作語言),用來對數據庫中表的數據記錄進行增刪改等操作 DML-添加數據 insert into employee(id, workno, name, gender, age, idcard) values (1,1,Itcast,男,10,123456789012345678);select *…

AI 與數據的智能融合丨大模型時代下的存儲系統

WOT 全球技術創新大會2024北京站于 6 月 22 日圓滿落幕。本屆大會以“智啟新紀,慧創萬物”為主題,邀請到 60 位不同行業的專家,聚焦 AIGC、領導力、研發效能、架構演進、大數據等熱門技術話題進行分享。 近年來,數據和人工智能已…

記錄搭建一臺可域名訪問的HTTPS服務器

一、背景 近期公司業務涉及到微信小程序,即將開發完成需要按照微信小程序平臺的要求提供帶證書的域名請求服務器。 資源背景介紹如下: 1、域名 公司已有一個二級域名,再次申請新的二級域名并且實現ICP備案不僅需要花重金重新購買,…