基于Bert模型的增量微調3-使用csv文件訓練

我們使用weibo評價數據,8分類的csv格式數據集。

一、創建數據集合

使用csv格式的數據作為數據集。

1、創建MydataCSV.py

from  torch.utils.data import Dataset
from datasets import load_datasetclass MyDataset(Dataset):#初始化數據集def __init__(self, split):# 加載csv數據self.dataset=load_dataset(path="csv",data_files=f"D:\Test\LLMTrain\day03\data\Weibo/{split}.csv", split= "train")# 返回數據集長度def __len__(self):return len(self.dataset)# 對每條數據單獨進行數據處理def __getitem__(self, idx):text=self.dataset[idx]["text"]label=self.dataset[idx]["label"]return  text,labelif __name__== "__main__":train_dataset=MyDataset("test")for i in range(10):print(train_dataset[i])

二、處理模型

我們使用8分類任務

創建netCSV.py

import torch
from transformers import BertModel#定義設備信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)#加載預訓練模型
path1=r"D:\Test\LLMTrain\day03\model\bert-base-chinese\models--google-bert--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
pretrained = BertModel.from_pretrained(path1).to(DEVICE)
print(pretrained)#定義下游任務(增量模型)
class Model(torch.nn.Module):def __init__(self):super().__init__()#設計全連接網絡,實現8分類任務self.fc = torch.nn.Linear(768,8)#使用模型處理數據(執行前向計算)def forward(self,input_ids,attention_mask,token_type_ids):#凍結Bert模型的參數,讓其不參與訓練with torch.no_grad():out = pretrained(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)#增量模型參與訓練out = self.fc(out.last_hidden_state[:,0])return out

8分類任務,所以?self.fc=torch.nn.Liner768,8) 。

我們是對大模型做增量微調訓練,所以需要凍結Bert模型的參數,讓其不參與訓練。所以使用?

with torch.no_grad()。

我們定義一個下游任務增量模型Model類,繼承 torch.nn.Module。

三、訓練的代碼

1、創建目錄params

存放訓練后的結果。

2、寫代碼

創建train_val_csv.py

#模型訓練
import torch
from MyDataCSV import MyDataset
from torch.utils.data import DataLoader
from netCSV import Model
from transformers import BertTokenizer,AdamW#定義設備信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#定義訓練的輪次(將整個數據集訓練完一次為一輪)
EPOCH = 30000#加載字典和分詞器
token = BertTokenizer.from_pretrained(r"D:\Test\LLMTrain\day03\model\bert-base-chinese\models--google-bert--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f")#將傳入的字符串進行編碼
def collate_fn(data):sents = [i[0]for i in data]label = [i[1] for i in data]#編碼data = token.batch_encode_plus(batch_text_or_text_pairs=sents,# 當句子長度大于max_length(上限是model_max_length)時,截斷truncation=True,max_length=512,# 一律補0到max_lengthpadding="max_length",# 可取值為tf,pt,np,默認為listreturn_tensors="pt",# 返回序列長度return_length=True)input_ids = data["input_ids"]attention_mask = data["attention_mask"]token_type_ids = data["token_type_ids"]label = torch.LongTensor(label)return input_ids,attention_mask,token_type_ids,label#創建數據集
train_dataset = MyDataset("train")
train_loader = DataLoader(dataset=train_dataset,#訓練批次batch_size=50,#打亂數據集shuffle=True,#舍棄最后一個批次的數據,防止形狀出錯drop_last=True,#對加載的數據進行編碼collate_fn=collate_fn
)
#創建驗證數據集
val_dataset = MyDataset("validation")
val_loader = DataLoader(dataset=val_dataset,#訓練批次batch_size=50,#打亂數據集shuffle=True,#舍棄最后一個批次的數據,防止形狀出錯drop_last=True,#對加載的數據進行編碼collate_fn=collate_fn
)
if __name__ == '__main__':#開始訓練print(DEVICE)model = Model().to(DEVICE)#定義優化器optimizer = AdamW(model.parameters())#定義損失函數loss_func = torch.nn.CrossEntropyLoss()#初始化驗證最佳準確率best_val_acc = 0.0for epoch in range(EPOCH):for i,(input_ids,attention_mask,token_type_ids,label) in enumerate(train_loader):#將數據放到DVEVICE上面input_ids, attention_mask, token_type_ids, label = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE),label.to(DEVICE)#前向計算(將數據輸入模型得到輸出)out = model(input_ids,attention_mask,token_type_ids)#根據輸出計算損失loss = loss_func(out,label)#根據誤差優化參數optimizer.zero_grad()loss.backward()optimizer.step()#每隔5個批次輸出訓練信息if i%5 ==0:out = out.argmax(dim=1)#計算訓練精度acc = (out==label).sum().item()/len(label)print(f"epoch:{epoch},i:{i},loss:{loss.item()},acc:{acc}")#驗證模型(判斷模型是否過擬合)#設置為評估模型model.eval()#不需要模型參與訓練with torch.no_grad():val_acc = 0.0val_loss = 0.0for i, (input_ids, attention_mask, token_type_ids, label) in enumerate(val_loader):# 將數據放到DVEVICE上面input_ids, attention_mask, token_type_ids, label = input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE), label.to(DEVICE)# 前向計算(將數據輸入模型得到輸出)out = model(input_ids, attention_mask, token_type_ids)# 根據輸出計算損失val_loss += loss_func(out, label)#根據數據,計算驗證精度out = out.argmax(dim=1)val_acc+=(out==label).sum().item()val_loss/=len(val_loader)val_acc/=len(val_loader)print(f"驗證集:loss:{val_loss},acc:{val_acc}")# #每訓練完一輪,保存一次參數# torch.save(model.state_dict(),f"params/{epoch}_bert.pth")# print(epoch,"參數保存成功!")#根據驗證準確率保存最優參數if val_acc > best_val_acc:best_val_acc = val_acctorch.save(model.state_dict(),"params1/best_bert.pth")print(f"EPOCH:{epoch}:保存最優參數:acc{best_val_acc}")#保存最后一輪參數torch.save(model.state_dict(), "params1/last_bert.pth")print(f"EPOCH:{epoch}:最后一輪參數保存成功!")

3、執行代碼

這個過程需等待很久,若是使用cuda環境,顯存越大,速度越快。

train_loader的訓練批次batch_size=50,這個數值是根據電腦的配置來的,數值越大越好,只要不超過顯存或者內存的90%即可。

四、使用訓練好的模型

我們寫一個控制臺程序,也可以使用FastAPI。創建run.py文件。

#模型使用接口(主觀評估)
#模型訓練
import torch
from net import Model
from transformers import BertTokenizer#定義設備信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")#加載字典和分詞器
token = BertTokenizer.from_pretrained(r"D:\Test\LLMTrain\day03\model\bert-base-chinese\models--google-bert--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f")
model = Model().to(DEVICE)
names = ["負向評價","正向評價"]#將傳入的字符串進行編碼
def collate_fn(data):sents = []sents.append(data)#編碼data = token.batch_encode_plus(batch_text_or_text_pairs=sents,# 當句子長度大于max_length(上限是model_max_length)時,截斷truncation=True,max_length=512,# 一律補0到max_lengthpadding="max_length",# 可取值為tf,pt,np,默認為listreturn_tensors="pt",# 返回序列長度return_length=True)input_ids = data["input_ids"]attention_mask = data["attention_mask"]token_type_ids = data["token_type_ids"]return input_ids,attention_mask,token_type_idsdef test():#加載模型訓練參數model.load_state_dict(torch.load("params/best_bert.pth"))#開啟測試模型model.eval()while True:data = input("請輸入測試數據(輸入‘q’退出):")if data=='q':print("測試結束")breakinput_ids,attention_mask,token_type_ids = collate_fn(data)input_ids, attention_mask, token_type_ids = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE)#將數據輸入到模型,得到輸出with torch.no_grad():out = model(input_ids,attention_mask,token_type_ids)out = out.argmax(dim=1)print("模型判定:",names[out],"\n")if __name__ == '__main__':test()

運行程序 ,輸入test測試集里的數據進行驗證,或許輸入其他的文本驗證。

?正確率還是非常棒的。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/73244.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/73244.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/73244.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

flowable新增或修改單個任務的歷史變量

簡介 場景:對歷史任務進行關注,所以需要修改流程歷史任務的本地變量 方法包含2個類 1)核心方法,flowable command類:HistoricTaskSingleVariableUpdateCmd 2)執行command類:BpmProcessCommandS…

Netty基礎—4.NIO的使用簡介一

大綱 1.Buffer緩沖區 2.Channel通道 3.BIO編程 4.偽異步IO編程 5.改造程序以支持長連接 6.NIO三大核心組件 7.NIO服務端的創建流程 8.NIO客戶端的創建流程 9.NIO優點總結 10.NIO問題總結 1.Buffer緩沖區 (1)Buffer緩沖區的作用 (2)Buffer緩沖區的4個核心概念 (3)使…

python元組(被捆綁的列表)

元組(tuple) 1.元組一旦形成就不可更改,元組所指向的內存單元中內容不變 定義:定義元組使用小括號,并且使用逗號進行隔開,數據可以是不同的數據類型 定義元組自變量(元素,元素,元素…

輸入:0.5元/百萬tokens(緩存命中)或2元(未命中) 輸出:8元/百萬tokens

這句話描述了一種 定價模型,通常用于云計算、API 服務或數據處理服務中,根據資源使用情況(如緩存命中與否)來收費。以下是對這句話的詳細解釋: 1. 關鍵術語解釋 Tokens:在自然語言處理(NLP&…

計算機視覺算法實戰——駕駛員玩手機檢測(主頁有源碼)

?個人主頁歡迎您的訪問 ?期待您的三連 ? ?個人主頁歡迎您的訪問 ?期待您的三連 ? ?個人主頁歡迎您的訪問 ?期待您的三連? ? ??? 1. 領域簡介:玩手機檢測的重要性與技術挑戰 駕駛員玩手機檢測是智能交通安全領域的核心課題。根據NHTSA數據&#xff0…

Java糊涂包(Hutool)的安裝教程并進行網絡爬蟲

Hutool的使用教程 1:在官網下載jar模塊文件 Central Repository: cn/hutool/hutool-all/5.8.26https://repo1.maven.org/maven2/cn/hutool/hutool-all/5.8.26/ 下載后綴只用jar的文件 2:復制并到idea當中,右鍵這個模塊點擊增加到庫 3&…

深度學習項目--基于DenseNet網絡的“乳腺癌圖像識別”,準確率090%+,pytorch復現

🍨 本文為🔗365天深度學習訓練營 中的學習記錄博客🍖 原作者:K同學啊 前言 如果說最經典的神經網絡,ResNet肯定是一個,從ResNet發布后,很多人做了修改,denseNet網絡無疑是最成功的…

優化用戶體驗:關鍵 Web 性能指標的獲取、分析、優化方法

前言 在當今互聯網高速發展的時代用戶對于網頁的加載速度和響應時間越來越敏感。一個性能表現不佳的網頁不僅會影響用戶體驗,還可能導致用戶流失。 因此,了解和優化網頁性能指標是每個開發者的必修課。今天我們就來聊聊常見的網頁性能指標以及如何獲取這…

vs code配置 c/C++

1、下載VSCode Visual Studio Code - Code Editing. Redefined 安裝目錄可改 勾選創建桌面快捷方式 安裝即可 2、漢化VSCode 點擊確定 下載MinGW 由于vsCode 只是一個編輯器,他沒有自帶編譯器,所以需要下載一個編譯器"MinGW". https://…

Kotlin關鍵字`when`的詳細用法

Kotlin關鍵字when的詳細用法 在Kotlin中,when是一個強大的控制流語句,相當于其他語言中的switch語句,但更加強大且靈活。本文將詳細講解when的用法及其常見場景,并與Java的switch語句進行對比。 一、基本語法 基本的when語法如…

MFCday01、模式對話框

對話框類和應用程序類。 MFC中 Combo Box List Box List Control三種列表控件,日期控件Date Time Picker

接口測試筆記

4、接口測試自動化 接口自動化概述 HttpClient HttpClient開發過程 創建Java工程 新建libs庫目錄 HttpClient 工具下載及引入 https://hc.apache.org/index.html工程中引入jar包 Get請求 HttpGet方法---發起Get請求 創建HttpClient對象 CloseableHttpClient httpclient …

查找sql中涉及的表名稱

import pandas as pd import datetime todaystr(datetime.date.today())filepath/Users/kangyongqing/Documents/kangyq/202303/分析模版/sql表引用提取/ file101試聽課明細.txt newfilefile1.title().split(.)[0]with open(filepathfile1,r) as file:contentfile.read().lower…

如何在Ubuntu上構建編譯LLVM和ISPC,以及Ubuntu上ISPC的使用方法

之前一直在 Mac 上使用 ISPC,奈何核心/線程太少了。最近想在 Ubuntu 上搞搞,但是 snap 安裝的 ISPC不知道為什么只能單核,很奇怪,就想著編譯一下,需要 Clang 和 LLVM。但是 Ubuntu 很搞,他的很多軟件版本是…

【Spring IOC/AOP】

IOC 參考: Spring基礎 - Spring核心之控制反轉(IOC) | Java 全棧知識體系 (pdai.tech) 概述: Ioc 即 Inverse of Control (控制反轉),是一種設計思想,就是將原本在程序中手動創建對象的控制權&#xff…

電感與電容的具體應用

文章目錄 一、電感應用1.?電源濾波:2. 儲能——平滑“電流波浪”? ?3. 調諧——校準“頻率樂器”?4. 限流——防止“洪水災害”?二、電容應用1.核心特性理解2.應用場景 三.電容電感對比 一、電感應用 1.?電源濾波: ?場景:工業設備中…

前端面試:axios 請求的底層依賴是什么?

在前端開發中,Axios 是一個流行的 JavaScript 庫,用于發送 HTTP 請求。它簡化了與 RESTful APIs 的交互,并提供了許多便利的方法與配置選項。要理解 Axios 的底層依賴,需要從以下幾個方面進行分析: 1. Axios 基于 XML…

springboot 3 集成Redisson

maven 依賴 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>3.2.12</version></parent><dependencies><dependency><groupId>org.red…

C#中繼承的核心定義?

1. 繼承的核心定義? ?繼承? 是面向對象編程&#xff08;OOP&#xff09;的核心特性之一&#xff0c;允許一個類&#xff08;稱為?子類/派生類?&#xff09;基于另一個類&#xff08;稱為?父類/基類?&#xff09;構建&#xff0c;自動獲得父類的成員&#xff08;字段、屬…

Deep research深度研究:ChatGPT/ Gemini/ Perplexity/ Grok哪家最強?(實測對比分析)

目前推出深度研究和深度檢索的AI大模型有四家&#xff1a; OpenAI和Gemini 的deep research&#xff0c;以及Perplexity 和Grok的deep search&#xff0c;都能生成帶參考文獻引用的主題報告。 致力于“幾分鐘之內生成一份完整的主題調研報告&#xff0c;解決人力幾小時甚至幾天…