nlp培訓重點-5

1. LoRA微調

loader:

# -*- coding: utf-8 -*-import json
import re
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
"""
數據加載
"""class DataGenerator:def __init__(self, data_path, config):self.config = configself.path = data_pathself.index_to_label = {0: '家居', 1: '房產', 2: '股票', 3: '社會', 4: '文化',5: '國際', 6: '教育', 7: '軍事', 8: '彩票', 9: '旅游',10: '體育', 11: '科技', 12: '汽車', 13: '健康',14: '娛樂', 15: '財經', 16: '時尚', 17: '游戲'}self.label_to_index = dict((y, x) for x, y in self.index_to_label.items())self.config["class_num"] = len(self.index_to_label)if self.config["model_type"] == "bert":self.tokenizer = BertTokenizer.from_pretrained(config["pretrain_model_path"])self.vocab = load_vocab(config["vocab_path"])self.config["vocab_size"] = len(self.vocab)self.load()def load(self):self.data = []with open(self.path, encoding="utf8") as f:for line in f:line = json.loads(line)tag = line["tag"]label = self.label_to_index[tag]title = line["title"]if self.config["model_type"] == "bert":input_id = self.tokenizer.encode(title, max_length=self.config["max_length"], pad_to_max_length=True)else:input_id = self.encode_sentence(title)input_id = torch.LongTensor(input_id)label_index = torch.LongTensor([label])self.data.append([input_id, label_index])returndef encode_sentence(self, text):input_id = []for char in text:input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))input_id = self.padding(input_id)return input_id#補齊或截斷輸入的序列,使其可以在一個batch內運算def padding(self, input_id):input_id = input_id[:self.config["max_length"]]input_id += [0] * (self.config["max_length"] - len(input_id))return input_iddef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]def load_vocab(vocab_path):token_dict = {}with open(vocab_path, encoding="utf8") as f:for index, line in enumerate(f):token = line.strip()token_dict[token] = index + 1  #0留給padding位置,所以從1開始return token_dict#用torch自帶的DataLoader類封裝數據
def load_data(data_path, config, shuffle=True):dg = DataGenerator(data_path, config)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dlif __name__ == "__main__":from config import Configdg = DataGenerator("valid_tag_news.json", Config)print(dg[1])

model:

import torch.nn as nn
from config import Config
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from torch.optim import Adam, SGDTorchModel = AutoModelForSequenceClassification.from_pretrained(Config["pretrain_model_path"])def choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)

evaluate:

# -*- coding: utf-8 -*-
import torch
from loader import load_data"""
模型效果測試
"""class Evaluator:def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = loggerself.valid_data = load_data(config["valid_data_path"], config, shuffle=False)self.stats_dict = {"correct":0, "wrong":0}  #用于存儲測試結果def eval(self, epoch):self.logger.info("開始測試第%d輪模型效果:" % epoch)self.model.eval()self.stats_dict = {"correct": 0, "wrong": 0}  # 清空上一輪結果for index, batch_data in enumerate(self.valid_data):if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]input_ids, labels = batch_data   #輸入變化時這里需要修改,比如多輸入,多輸出的情況with torch.no_grad():pred_results = self.model(input_ids)[0]self.write_stats(labels, pred_results)acc = self.show_stats()return accdef write_stats(self, labels, pred_results):# assert len(labels) == len(pred_results)for true_label, pred_label in zip(labels, pred_results):pred_label = torch.argmax(pred_label)# print(true_label, pred_label)if int(true_label) == int(pred_label):self.stats_dict["correct"] += 1else:self.stats_dict["wrong"] += 1returndef show_stats(self):correct = self.stats_dict["correct"]wrong = self.stats_dict["wrong"]self.logger.info("預測集合條目總量:%d" % (correct +wrong))self.logger.info("預測正確條目:%d,預測錯誤條目:%d" % (correct, wrong))self.logger.info("預測準確率:%f" % (correct / (correct + wrong)))self.logger.info("--------------------")return correct / (correct + wrong)

?main:

# -*- coding: utf-8 -*-import torch
import os
import random
import os
import numpy as np
import torch.nn as nn
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
from peft import get_peft_model, LoraConfig, \PromptTuningConfig, PrefixTuningConfig, PromptEncoderConfig #[DEBUG, INFO, WARNING, ERROR, CRITICAL]
logging.basicConfig(level=logging.INFO, format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)"""
模型訓練主程序
"""seed = Config["seed"]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)def main(config):#創建保存模型的目錄if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])#加載訓練數據train_data = load_data(config["train_data_path"], config)#加載模型model = TorchModel#大模型微調策略tuning_tactics = config["tuning_tactics"]if tuning_tactics == "lora_tuning":peft_config = LoraConfig(r=8,lora_alpha=32,lora_dropout=0.1,target_modules=["query", "key", "value"])elif tuning_tactics == "p_tuning":peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=10)elif tuning_tactics == "prompt_tuning":peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)elif tuning_tactics == "prefix_tuning":peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)model = get_peft_model(model, peft_config)# print(model.state_dict().keys())if tuning_tactics == "lora_tuning":# lora配置會凍結原始模型中的所有層的權重,不允許其反傳梯度# 但是事實上我們希望最后一個線性層照常訓練,只是bert部分被凍結,所以需要手動設置for param in model.get_submodule("model").get_submodule("classifier").parameters():param.requires_grad = True# 標識是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,遷移模型至gpu")model = model.cuda()#加載優化器optimizer = choose_optimizer(config, model)#加載效果測試類evaluator = Evaluator(config, model, logger)#訓練for epoch in range(config["epoch"]):epoch += 1model.train()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):if cuda_flag:batch_data = [d.cuda() for d in batch_data]optimizer.zero_grad()input_ids, labels = batch_data   #輸入變化時這里需要修改,比如多輸入,多輸出的情況output = model(input_ids)[0]loss = nn.CrossEntropyLoss()(output, labels.view(-1))loss.backward()optimizer.step()train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)logger.info("epoch average loss: %f" % np.mean(train_loss))acc = evaluator.eval(epoch)model_path = os.path.join(config["model_path"], "%s.pth" % tuning_tactics)save_tunable_parameters(model, model_path)  #保存模型權重return accdef save_tunable_parameters(model, path):saved_params = {k: v.to("cpu")for k, v in model.named_parameters()if v.requires_grad}torch.save(saved_params, path)if __name__ == "__main__":main(Config)

pred:

import torch
import logging
from model import TorchModel
from peft import get_peft_model, LoraConfig, PromptTuningConfig, PrefixTuningConfig, PromptEncoderConfigfrom evaluate import Evaluator
from config import Configlogging.basicConfig(level=logging.INFO, format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)#大模型微調策略
tuning_tactics = Config["tuning_tactics"]print("正在使用 %s"%tuning_tactics)if tuning_tactics == "lora_tuning":peft_config = LoraConfig(r=8,lora_alpha=32,lora_dropout=0.1,target_modules=["query", "key", "value"])
elif tuning_tactics == "p_tuning":peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
elif tuning_tactics == "prompt_tuning":peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
elif tuning_tactics == "prefix_tuning":peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)#重建模型
model = TorchModel
# print(model.state_dict().keys())
# print("====================")model = get_peft_model(model, peft_config)
# print(model.state_dict().keys())
# print("====================")state_dict = model.state_dict()#將微調部分權重加載
if tuning_tactics == "lora_tuning":loaded_weight = torch.load('output/lora_tuning.pth')
elif tuning_tactics == "p_tuning":loaded_weight = torch.load('output/p_tuning.pth')
elif tuning_tactics == "prompt_tuning":loaded_weight = torch.load('output/prompt_tuning.pth')
elif tuning_tactics == "prefix_tuning":loaded_weight = torch.load('output/prefix_tuning.pth')print(loaded_weight.keys())
state_dict.update(loaded_weight)#權重更新后重新加載到模型
model.load_state_dict(state_dict)#進行一次測試
model = model.cuda()
evaluator = Evaluator(Config, model, logger)
evaluator.eval(0)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/72909.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/72909.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/72909.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

CI/CD—Jenkins配置Maven+GitLab自動構建jar包

一、安裝Maven插件通過Maven構建項目 1、在Jenkins上安裝Maven Integration plugin插件 2、創建一個maven項目 2.1、填寫構建的名稱和描述等 2.2、填寫連接git的url 報錯:無法連接倉庫:Error performing git command: git ls-remote -h http://192.168.…

ngx_regex_create_conf

ngx_regex_create_conf 定義在 src/core/ngx_regex.c static void * ngx_regex_create_conf(ngx_cycle_t *cycle) {ngx_regex_conf_t *rcf;ngx_pool_cleanup_t *cln;rcf ngx_pcalloc(cycle->pool, sizeof(ngx_regex_conf_t));if (rcf NULL) {return NULL;}rcf->p…

【數據結構】初識集合框架及背后的數據結構(簡單了解)

目錄 前言 如何學好數據結構 1. 什么是集合框架 2. 集合框架的重要性 3. 背后所涉及的數據結構以及算法 3.1 什么是數據結構 3.2 容器背后對應的數據結構 3.3 相關java知識 3.4 什么是算法 3.5 基本關系說明(重要,簡單了解) 前言 …

P9242 [藍橋杯 2023 省 B] 接龍數列--DP【巧妙解決接龍問題】

P9242 [藍橋杯 2023 省 B] 接龍數列--DP 題目 解析什么時候該用 DP?動態規劃 vs 其他方法代碼 題目 解析 這題沒思路,壓根沒想到DP 😦 看了大神的題解,利用dp記錄每一個數結尾的長度,最后再用N-dp中的最大值&#xf…

用《設計模式》的角度優化 “枚舉”

枚舉應該都有用過,枚舉主要的作用是為了方便用戶查找和引用枚舉。 案例一 下面的枚舉邏輯很簡單,就是通過枚舉值返回不同的結果。 public enum OperationEnum {EQUAL_TO,CONTAINS,START_WITH,END_WITH;public String getOperationValue(String value)…

SQL根據分隔符折分不同的內容放到臨時表

SQL Server存儲過程里根據分隔符折分不同的內容放到臨時表里做查詢條件,以下分隔符使用“/”,可修改不同分隔符 --根據分隔符折分不同的內容放到臨時表--------------- SELECT ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) AS id, LTRIM(RTR…

Ubuntu切換lowlatency內核

文章目錄 一. 前言二. 開發環境三. 具體操作 一. 前言 低延遲內核(Lowlatency Kernel) 旨在為需要低延遲響應的應用程序設計的內核版本。Linux-lowlatency特別適合音頻處理、實時計算、游戲和其他需要及時響應的實時任務。其主要特點是優化了中斷處理、調…

基于Django創建一個WEB后端框架(DjangoRestFramework+MySQL)流程

一、Django項目初始化 1.創建Django項目 Django-admin startproject 項目名 2.安裝 djangorestframework pip install djangorestframework 解釋: Django REST Framework (DRF) 是基于 Django 框架的一個強大的 Web API 框架,提供了多種工具和庫來構建 RESTf…

VUE3開發-9、axios前后端跨域問題解決方案

VUE前端解決跨域問題 前端頁面需要改寫 如果無效,記得重啟服務器 后端c#解決跨域問題 前端js取值,后端c#跨域_c# js跨域-CSDN博客

DailyNotes 增加提醒功能

TODO:準備給 DailyNotes 增加一個提醒功能,準備接入 AI 來做一些事情。試了一下,非常靠譜。 具體 DailyNotes 和 Ollama 的交互方式,可以直接調用命令行,也可以走網絡API。 rayuK2CD9WCYN4 ~ % ollama run deepseek-…

PY32MD320單片機 QFN32封裝,內置多功能三相 NN 型預驅。

PY32MD320單片機是普冉半導體的一款電機專用MCU,芯片采用了高性能的 32 位 ARM Cortex-M0 內核,主要用于電機控制。PY32MD320嵌入高達 64 KB Flash 和 8 KB SRAM 存儲器,最高工作頻率 48 MHz。PY32MD320單片機的工作溫度范圍為 -40 ~ 105 ℃&…

OpenManus介紹及本地部署體驗

1.OpenManus介紹 OpenManus,由 MetaGPT 團隊精心打造的開源項目,于2025年3月發布。它致力于模仿并改進 Manus 這一封閉式商業 AI Agent 的核心功能,為用戶提供無需邀請碼、可本地化部署的智能體解決方案。換句話說,OpenManus 就像…

【貪心算法】簡介

1.貪心算法 貪心策略:解決問題的策略,局部最優----》全局最優 (1)把解決問題的過程分成若干步 (2)解決每一步的時候,都選擇當前看起來的“最優”的算法 (3)“希望”得…

springboot知識點以及源碼解析(2)

web開發--靜態規則與定制化 springboot對靜態資源的映射規則:在類路徑下面定義目錄static或public或resources或者META-INF/resources,訪問時項目根目錄靜態資源的名稱 在springboot中,如果項目中存在同名的靜態資源和同名的動態資源。那么我…

C++:string容器(下篇)

1.string淺拷貝的問題 // 為了和標準庫區分,此處使用String class String { public :/*String():_str(new char[1]){*_str \0;}*///String(const char* str "\0") // 錯誤示范//String(const char* str nullptr) // 錯誤示范String(const char* str …

使用 vxe-table 導出 excel,支持帶數值、貨幣、圖片等帶格式導出

使用 vxe-table 導出 excel,支持帶數值、貨幣、圖片等帶格式導出,通過官方自動的導出插件 plugin-export-xlsx 實現導出功能 查看官網:https://vxetable.cn gitbub:https://github.com/x-extends/vxe-table gitee:htt…

JavaScript數據類型和內存空間

一、JavaScript 數據類型 基本數據類型:字符串(String)、數字(Number)、布爾(Boolean)、空(Null)、未定義(Undefined)、Symbol 引用數據類型:對象(Object)、數組(Array)、函數(Fun…

DNS Beaconing

“DNS Beaconing” 是一種隱蔽的網絡通信技術,通常與惡意軟件(如木馬、僵尸網絡)相關。攻擊者通過定期發送 DNS請求 到受控的域名服務器(C&C服務器),實現與惡意軟件的隱蔽通信、數據傳輸或指令下發。由…

python中采用opencv作常規的圖片處理的方法~~~

在python中,我們經常會需要對圖片做灰度/二值化/模糊等處理,這時候opencv就是我們的好幫手了,下面我來介紹一下相關用法: 首先,需要安裝opencv-python庫: 然后,在你的代碼中引用: import cv2 最后就是代碼了&#x…

CmBacktrace的學習跟移植思路

學習移植CmBacktrace需要從理解其核心功能、適用場景及移植步驟入手,結合理論學習和實踐操作。以下是具體的學習思路與移植思路: 一、學習思路 理解CmBacktrace的核心功能 CmBacktrace是針對ARM Cortex-M系列MCU的錯誤追蹤庫,支持自動診斷Har…