hugging face筆記:PEFT

1 介紹

  • PEFT (Parameter-Efficient Fine Tuning) 方法在微調時凍結預訓練模型參數,并在其上添加少量可訓練的參數(稱為適配器)
  • 這些適配器被訓練用來學習特定任務的信息。
  • 這種方法已被證明在內存效率和計算使用上非常高效,同時能產生與完全微調模型相當的結果
  • 使用PEFT訓練的適配器通常比完整模型小一個數量級,這使得分享、存儲和加載它們變得非常方便。
    • 例如,一個OPTForCausalLM模型的適配器權重在Hub上的存儲只有約6MB
    • 相比之下,完整的模型權重可以達到約700MB。

2 加載 PEFT適配器

2.1 直接from_pretrained加載

  • 若要從?Transformers 加載和使用 PEFT 適配器模型,請確保 Hub 存儲庫或本地目錄包含 adapter_config.json 文件和適配器權重
  • 然后,可以使用 AutoModel類加載 PEFT 適配器模型
from transformers import AutoModelForCausalLM, AutoTokenizerpeft_model_id = "ybelkada/opt-350m-lora"
model = AutoModelForCausalLM.from_pretrained(peft_model_id)

2.2 load_adapter加載

from transformers import AutoModelForCausalLMmodel_id = "facebook/opt-350m"
peft_model_id = "ybelkada/opt-350m-lora"model = AutoModelForCausalLM.from_pretrained(model_id)
model.load_adapter(peft_model_id)

?2.3 以8位/4位加載

  • bitsandbytes 集成支持 8 位和 4 位精度數據類型,這對于加載大型模型非常有用,因為它節省了內存
  • 在 from_pretrained() 中添加 load_in_8bit 或 load_in_4bit 參數,并設置 device_map="auto" 以有效地將模型分配到你的硬件
from transformers import AutoModelForCausalLM, AutoTokenizerpeft_model_id = "ybelkada/opt-350m-lora"
model = AutoModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", load_in_8bit=True)

3 添加適配器?

from transformers import AutoModelForCausalLM
from peft import LoraConfigmodel_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_id)
#加載這個模型lora_config = LoraConfig(target_modules=["q_proj", "k_proj"],init_lora_weights=False
)
'''
target_modules 參數指定了將 LoRA 適配器應用于模型的哪些部分這里是 "q_proj"(查詢投影)和 "k_proj"(鍵投影)init_lora_weights 設置為 False,意味著在初始化時不加載 LoRA 權重
'''model.add_adapter(lora_config, adapter_name="adapter_1")
#使用 add_adapter 方法將之前配置的 LoRA 適配器添加到模型中,適配器命名為 "adapter_1"model.add_adapter(lora_config, adapter_name="adapter_2")
# 附加具有相同配置的新適配器"adapter_2"

4 設置使用哪個適配器

# 使用 adapter_1
model.set_adapter("adapter_1")
output = model.generate(**inputs)
print(tokenizer.decode(output_disabled[0], skip_special_tokens=True))# 使用 adapter_2
model.set_adapter("adapter_2")
output_enabled = model.generate(**inputs)
print(tokenizer.decode(output_enabled[0], skip_special_tokens=True))

5 啟用和禁用適配器

一旦向模型添加了適配器,可以啟用或禁用適配器模塊

from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import PeftConfigmodel_id = "facebook/opt-350m"
adapter_model_id = "ybelkada/opt-350m-lora"tokenizer = AutoTokenizer.from_pretrained(model_id)
text = "Hello"
inputs = tokenizer(text, return_tensors="pt")
#加載分詞器和初始化輸入model = AutoModelForCausalLM.from_pretrained(model_id)
peft_config = PeftConfig.from_pretrained(adapter_model_id)
'''
加載了預訓練的基礎模型 facebook/opt-350m 和適配器的配置。PeftConfig.from_pretrained 方法用于加載預定義的適配器配置。
'''peft_config.init_lora_weights = False
model.add_adapter(peft_config)
'''
在添加適配器前,設置 init_lora_weights = False 指明在初始化時不使用預訓練的 LoRA 權重,
而是使用隨機權重。然后將適配器添加到模型中。
'''model.enable_adapters()
output1 = model.generate(**inputs)
#啟用適配器,然后使用啟用了適配器的模型生成文本model.disable_adapters()
output2 = model.generate(**inputs)
#禁用適配器后,再次生成文本以查看不使用適配器時模型的輸出表現tokenizer.decode(output1[0])
'''
'</s>Hello------------------'
'''tokenizer.decode(output2[0])
'''
"</s>Hello, I'm a newbie to this sub. I'm looking for a good place to"
'''

6 訓練PEFT適配器

6.1 舉例:添加lora適配器

6.1.1?定義你的適配器配置

from peft import LoraConfigpeft_config = LoraConfig(lora_alpha=16,lora_dropout=0.1,r=64,bias="none",task_type="CAUSAL_LM",
)
  • lora_alpha=16:指定 LoRA 層的縮放因子。
  • lora_dropout=0.1:設置在 LoRA 層中使用的 dropout 比率,以避免過擬合。
  • r=64:設置每個 LoRA 層的秩,即低秩矩陣的維度。
  • bias="none":指定不在 LoRA 層中使用偏置項。
  • task_type="CAUSAL_LM":設定這個 LoRA 配置是為了因果語言模型任務。

6.1.2 將適配器添加到模型

model.add_adapter(peft_config)

6.1.3將模型傳遞給 Trainer以進行訓練

from transformers import Trainer
trainer = Trainer(model=model, ...)
trainer.train()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/16345.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/16345.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/16345.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

線性模型--普通最小二乘法

線性模型 一、模型介紹二、用于回歸的線性模型2.1 線性回歸&#xff08;普通最小二乘法&#xff09; 一、模型介紹 線性模型是在實踐中廣泛使用的一類模型&#xff0c;該模型利用輸入特征的線性函數進行預測。 二、用于回歸的線性模型 以下代碼可以在一維wave數據集上學習參…

基于51單片機的超聲波液位測量與控制系統

基于51單片機液位控制器 &#xff08;仿真&#xff0b;程序&#xff0b;原理圖PCB&#xff0b;設計報告&#xff09; 功能介紹 具體功能&#xff1a; 1.使用HC-SR04測量液位&#xff0c;LCD1602顯示&#xff1b; 2.當水位高于設定上限的時候&#xff0c;對應聲光報警報警&am…

手機卡該地塊

package demo; package demo; public class Phonetest { public static void main(String[] args) { Phone pnew Phone(); p.brand"小米"; p.price1998.98; System.out.println(…

在業務開發中使用ElasticSearch的指導手冊

文章目錄 該業務為什么需要ElasticSearch? / 該業務需要ElasticSearch的核心功能是哪些&#xff1f;正確示例錯誤示例 如何快速驗證分詞是否能夠滿足業務需求&#xff1f;分詞不滿足&#xff0c;如何自定義分詞&#xff1f; 業務數據的字段類型映射是否合理&#xff1f;實踐中…

MySQL設置表自增步長

在MySQL數據庫管理中&#xff0c;自增字段&#xff08;AUTO_INCREMENT&#xff09;是一種常見且重要的功能&#xff0c;通常用于生成唯一的標識符&#xff08;如主鍵&#xff09;。然而&#xff0c;在多種應用場景下&#xff0c;默認的自增步長&#xff08;1&#xff09;可能無…

【InternLM實戰營第二期筆記】02:大模型全鏈路開源體系與趣味demo

文章目錄 00 環境設置01 部署一個 chat 小模型02 Lagent 運行 InternLM2-chat-7B03 浦語靈筆2 第二節課程視頻與文檔&#xff1a; https://www.bilibili.com/video/BV1AH4y1H78d/ https://github.com/InternLM/Tutorial/blob/camp2/helloworld/hello_world.md 視頻和文檔內容基…

003 CentOS 7.9 mysql8.3.0安裝及配置

文章目錄 Windows PowerShell測試端口安裝及配置1. 下載MySQL安裝包2. 解壓安裝包3. 安裝MySQL4. 啟動MySQL服務5. 獲取并設置MySQL root密碼6. 創建數據庫7. 配置遠程連接&#xff08;可選&#xff09; 卸載mysql檢查并卸載已有的MySQL或MariaDB&#xff1a; https://download…

云計算和大數據處理

文章目錄 1.云計算基礎知識1.1 基本概念1.2 云計算分類 2.大數據處理基礎知識2.1 基礎知識2.3 大數據處理技術 1.云計算基礎知識 1.1 基本概念 云計算是一種提供資源的網絡&#xff0c;使用者可以隨時獲取“云”上的資源&#xff0c;按需求量使用&#xff0c;并且可以看成是無…

AWS安全性身份和合規性之WAF(Web Application Firewall)

AWS WAF&#xff08;Web Application Firewall&#xff09;是一項AWS托管的網絡安全服務&#xff0c;用于保護Web應用程序免受常見的Web攻擊&#xff0c;如SQL注入、跨站腳本&#xff08;XSS&#xff09;、跨站請求偽造&#xff08;CSRF&#xff09;等。 應用場景&#xff1a;…

STM32應用開發進階--IIC總線(SHT20溫濕度+HAL庫_硬件I2C)

實現目標 1、掌握IIC總線基礎知識&#xff1b; 2、會使用軟件模擬IIC總線和使用STM32硬件IIC總線&#xff1b; 3、 學會STM32CubeMX軟件關于IIC的配置; 4、掌握SHT20溫濕度傳感器的驅動&#xff1b; 5、具體目標&#xff1a;&#xff08;1&#xff09;用STM32硬件IIC驅動S…

49 序列化和反序列化

本章重點 理解應用層的作用&#xff0c;初識http協議 理解傳輸層的作用&#xff0c;深入理解tcp的各項特性和機制 對整個tcp/ip協議有系統的理解 對tcp/ip協議體系下的其他重要協議和技術有一定的了解 學會使用一些網絡問題的工具和方法 目錄 1.應用層 2.協議概念 3. 網絡計…

CSRF跨站請求偽造實戰

目錄 一、定義 二、與XSS的區別 三、攻擊要點 四、實戰 一、定義 CSRF (Cross-site request forgery&#xff0c;跨站請求偽造)&#xff0c;攻擊者利用服務器對用戶的信任&#xff0c;從而欺騙受害者去服務器上執行受害者不知情的請求。在CSRF的攻擊場景中&#xff0c;攻擊…

Django模板層——模板引擎配置

作為Web 框架&#xff0c;Django 需要一種很便利的方法以動態地生成HTML。最常見的做法是使用模板。 模板包含所需HTML 輸出的靜態部分&#xff0c;以及一些特殊的語法&#xff0c;描述如何將動態內容插入。 模板引擎配置 模板引擎使用該TEMPLATES設置進行配置。這是一個配置列…

C++數據結構——哈希桶HashBucket

目錄 一、前言 1.1 閉散列 1.2 開散列 1.3 string 與 非 string 二、哈希桶的構成 2.1 哈希桶的節點 2.2 哈希桶類 三、 Insert 函數 3.1 無需擴容時 3.2 擴容 復用 Insert&#xff1a; 逐個插入&#xff1a; 優缺點比對&#xff1a; 第一種寫法優點 第一種寫法…

gfast:基于全新Go Frame 2.3+Vue3+Element Plus構建的全棧前后端分離管理系統

gfast&#xff1a;基于全新Go Frame 2.3Vue3Element Plus構建的全棧前后端分離管理系統 隨著信息技術的飛速發展和數字化轉型的深入&#xff0c;后臺管理系統在企業信息化建設中扮演著越來越重要的角色。為了滿足市場對于高效、靈活、安全后臺管理系統的需求&#xff0c;gfast應…

OpenUI 可視化 AI:打造令人驚艷的前端設計!

https://openui.fly.dev/ai/new 可視化UI的新時代&#xff1a;通過人工智能生成前端代碼 許久未更新, 前端時間在逛github&#xff0c;發現一個挺有的意思項目&#xff0c;通過口語化方式生成前端UI頁面&#xff0c;能夠直觀的看到效果&#xff0c;下面來給大家演示下 在現代…

SAP FS00如何導出會計總賬科目表

輸入T-code : S_ALR_87012333 根據‘FS00’中找到的總賬科目&#xff0c;進行篩選執行 點擊左上角的列表菜單&#xff0c;選擇‘電子表格’導出即可

echarts-地圖

使用地圖的三種的方式&#xff1a; 注冊地圖(用json或svg,注冊為地圖)&#xff0c;然后使用map地圖使用geo坐標系&#xff0c;地圖注冊后不是直接使用&#xff0c;而是注冊為坐標系。直接使用百度地圖、高德地圖&#xff0c;使用百度地圖或高德地圖作為坐標系。 用json或svg注…

C++中string類的初步介紹

C語言中的字符串 在C語言中&#xff0c;字符串是以\0結尾的一些字符的集合&#xff0c;C標準庫中提供了一系列str系列的庫函數&#xff0c;但這些庫函數與字符串是分離的&#xff0c;不符合面向對象的編程思想。 string類的大致介紹 1.string是表示字符串的字符串類 2.stri…

GpuMall智算云:meta-llama/llama3/Llama3-8B-Instruct-WebUI

LLaMA 模型的第三代&#xff0c;是 LLaMA 2 的一個更大和更強的版本。LLaMA 3 擁有 35 億個參數&#xff0c;訓練在更大的文本數據集上GpuMall智算云 | 省錢、好用、彈性。租GPU就上GpuMall,面向AI開發者的GPU云平臺 Llama 3 的推出標志著 Meta 基于 Llama 2 架構推出了四個新…