大模型微調---Prompt-tuning微調

目錄

    • 一、前言
    • 二、Prompt-tuning實戰
      • 2.1、下載模型到本地
      • 2.2、加載模型與數據集
      • 2.3、處理數據
      • 2.4、Prompt-tuning微調
      • 2.5、訓練參數配置
      • 2.6、開始訓練
    • 三、模型評估
    • 四、完整訓練代碼

一、前言

Prompt-tuning通過修改輸入文本的提示(Prompt)來引導模型生成符合特定任務或情境的輸出,而無需對模型的全量參數進行微調。
在這里插入圖片描述
Prompt-Tuning 高效微調只會訓練新增的Prompt的表示層,模型的其余參數全部固定,其核心在于將下游任務轉化為預訓練任務

在這里插入圖片描述
新增的 Prompt 內容可以分為 Hard PromptSoft Prompt 兩類:

  • Soft prompt 通常指的是一種較為寬泛或模糊的提示,允許模型在生成結果時有更大的自由度,通常用于啟發模型進行創造性的生成;
  • Hard prompt 是一種更為具體和明確的提示,要求模型按照給定的信息生成精確的結果,通常用于需要模型提供準確答案的任務;

Soft Prompt 在 peft 中一般是隨機初始化prompt的文本內容,而 Hard prompt 則一般需要設置具體的提示文本內容;

對于不同任務的Prompt的構建示例如下:
在這里插入圖片描述

例如,假設我們有興趣將英語句子翻譯成德語。我們可以通過各種不同的方式詢問模型,如下圖所示。

1)“Translate the English sentence ‘{english_sentence}’ into German: {german_translation}”
2)“English: ‘{english sentence}’ | German: {german translation}”
3)“From English to German:‘{english_sentence}’-> {german_translation}”

上面說明的這個概念被稱為硬提示調整

軟提示調整(soft prompt tuning)將輸入標記的嵌入與可訓練張量連接起來,該張量可以通過反向傳播進行優化,以提高目標任務的建模性能。

例如下方偽代碼:

# 定義可訓練的軟提示參數
# 假設我們有 num_tokens 個軟提示 token,每個 token 的維度為 embed_dim
soft_prompt = torch.nn.Parameter(torch.rand(num_tokens, embed_dim)  # 隨機初始化軟提示向量
)# 定義一個函數,用于將軟提示與原始輸入拼接
def input_with_softprompt(x, soft_prompt):# 假設 x 的維度為 (batch_size, seq_len, embed_dim)# soft_prompt 的維度為 (num_tokens, embed_dim)# 將 soft_prompt 在序列維度上與 x 拼接# 拼接后的張量維度為 (batch_size, num_tokens + seq_len, embed_dim)x = concatenate([soft_prompt, x], dim=seq_len)return x# 將包含軟提示的輸入傳入模型
output = model(input_with_softprompt(x, soft_prompt))
  1. 軟提示參數:

使用 torch.nn.Parameter 將隨機初始化的向量注冊為可訓練參數。這意味著在訓練過程中,soft_prompt 中的參數會隨梯度更新而優化。

  1. 拼接輸入:

函數 input_with_softprompt 接收原始輸入 x(通常是嵌入后的 token 序列)和 soft_prompt 張量。通過 concatenate(偽代碼中使用此函數代指張量拼接操作),將軟提示向量沿著序列長度維度與輸入拼接在一起。

  1. 傳遞給模型:

將包含軟提示的輸入張量傳給模型,以引導模型在執行特定任務(如分類、生成、QA 等)時更好地利用這些可訓練的軟提示向量。

二、Prompt-tuning實戰

預訓練模型與分詞模型——Qwen/Qwen2.5-0.5B-Instruct
數據集——lyuricky/alpaca_data_zh_51k

2.1、下載模型到本地

# 下載數據集
dataset_file = load_dataset("lyuricky/alpaca_data_zh_51k", split="train", cache_dir="./data/alpaca_data")
ds = load_dataset("./data/alpaca_data", split="train")# 下載分詞模型
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Save the tokenizer to a local directory
tokenizer.save_pretrained("./local_tokenizer_model")#下載與訓練模型
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # 下載模型的路徑torch_dtype="auto",low_cpu_mem_usage=True,cache_dir="./local_model_cache"  # 指定本地緩存目錄
)

2.2、加載模型與數據集

#加載分詞模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")# 加載數據集
ds = load_dataset("../data/alpaca_data", split="train[:10%]")# 記載模型
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775",torch_dtype="auto",device_map="cuda:0")

2.3、處理數據

"""
并將其轉換成適合用于模型訓練的輸入格式。具體來說,
它將原始的輸入數據(如用戶指令、用戶輸入、助手輸出等)轉換為模型所需的格式,
包括 input_ids、attention_mask 和 labels。
"""
def process_func(example, tokenizer=tokenizer_model):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")if example["output"] is not None:response = tokenizer(example["output"] + tokenizer.eos_token)else:returninput_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}# 分詞
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)

2.4、Prompt-tuning微調

soft Prompt

# Soft Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt會隨機初始化

Hard Prompt

# Hard Prompt
prompt = "下面是一段人與機器人的對話。"
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, prompt_tuning_init=PromptTuningInit.TEXT,prompt_tuning_init_text=prompt,num_virtual_tokens=len(tokenizer_model(prompt)["input_ids"]),tokenizer_name_or_path="../local_tokenizer_model")

加載peft配置

peft_model = get_peft_model(model, config)print(peft_model.print_trainable_parameters())

在這里插入圖片描述
可以看到要訓練的模型相比較原來的全量模型要少很多

2.5、訓練參數配置

# 配置模型參數
args = TrainingArguments(output_dir="chatbot",   # 訓練模型的輸出目錄per_device_train_batch_size=1,gradient_accumulation_steps=4,logging_steps=10,num_train_epochs=1,
)

2.6、開始訓練

# 創建訓練器
trainer = Trainer(args=args,model=model,train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer_model, padding=True )
)
# 開始訓練
trainer.train()

可以看到 ,損失有所下降

在這里插入圖片描述

三、模型評估

# 模型推理
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipelinemodel = AutoModelForCausalLM.from_pretrained("../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775", low_cpu_mem_usage=True)
peft_model = PeftModel.from_pretrained(model=model, model_id="./chatbot/checkpoint-643")
peft_model = peft_model.cuda()#加載分詞模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")
ipt = tokenizer_model("Human: {}\n{}".format("我們如何在日常生活中減少用水?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(peft_model.device)
print(tokenizer_model.decode(peft_model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))print("-----------------")
#預訓練的管道流
# 構建prompt
ipt = "Human: {}\n{}".format("我們如何在日常生活中減少用水?", "").strip() + "\n\nAssistant: "
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer_model)
output = pipe(ipt, max_length=256, do_sample=True, truncation=True)
print(output)

訓練了一輪,感覺效果不大,可以增加訓練輪數試試
在這里插入圖片描述

四、完整訓練代碼


from datasets import load_dataset
from peft import PromptTuningConfig, TaskType, PromptTuningInit, get_peft_model, PeftModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, TrainingArguments, \DataCollatorForSeq2Seq, Trainer# 下載數據集
# dataset_file = load_dataset("lyuricky/alpaca_data_zh_51k", split="train", cache_dir="./data/alpaca_data")
# ds = load_dataset("./data/alpaca_data", split="train")
# print(ds[0])# 下載分詞模型
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Save the tokenizer to a local directory
# tokenizer.save_pretrained("./local_tokenizer_model")#下載與訓練模型
# model = AutoModelForCausalLM.from_pretrained(
#     pretrained_model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # 下載模型的路徑
#     torch_dtype="auto",
#     low_cpu_mem_usage=True,
#     cache_dir="./local_model_cache"  # 指定本地緩存目錄
# )#加載分詞模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")# 加載數據集
ds = load_dataset("../data/alpaca_data", split="train[:10%]")# 記載模型
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775",torch_dtype="auto",device_map="cuda:0")# 處理數據
"""
并將其轉換成適合用于模型訓練的輸入格式。具體來說,
它將原始的輸入數據(如用戶指令、用戶輸入、助手輸出等)轉換為模型所需的格式,
包括 input_ids、attention_mask 和 labels。
"""
def process_func(example, tokenizer=tokenizer_model):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")if example["output"] is not None:response = tokenizer(example["output"] + tokenizer.eos_token)else:returninput_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}# 分詞
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)prompt = "下面是一段人與機器人的對話。"# prompt-tuning
# Soft Prompt
# config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt會隨機初始化
# Hard Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, prompt_tuning_init=PromptTuningInit.TEXT,prompt_tuning_init_text=prompt,num_virtual_tokens=len(tokenizer_model(prompt)["input_ids"]),tokenizer_name_or_path="../local_tokenizer_model")peft_model = get_peft_model(model, config)print(peft_model.print_trainable_parameters())# 訓練參數args = TrainingArguments(output_dir="./chatbot",per_device_train_batch_size=1,gradient_accumulation_steps=8,logging_steps=10,num_train_epochs=1
)# 創建訓練器
trainer = Trainer(model=peft_model, args=args, train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_model, padding=True))# 開始訓練
trainer.train()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/63635.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/63635.shtml
英文地址,請注明出處:http://en.pswp.cn/web/63635.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Visual Studio 、 MSBuild 、 Roslyn 、 .NET Runtime、SDK Tools之間的關系

1. Visual Studio Visual Studio 是一個集成開發環境(IDE),為開發者提供代碼編寫、調試、測試和發布等功能。它內置了 MSBuild、Roslyn 和 SDK Tools,并提供圖形化界面來方便開發者進行項目管理和構建。與其他組件的關系&#xf…

Winnows基礎(2)

Target 了解常見端口及服務,熟練cmd命令,編寫簡單的 .bat 病毒程序。 Trail 常見服務及端口 80 web 80-89 可能是web 443 ssl心臟滴血漏洞以及一些web漏洞測試 445 smb 1433 mssql 1521 oracle 2082/2083 cpanel主機管理系統登陸(國外用的…

Edge Scdn用起來怎么樣?

Edge Scdn:提升網站安全與性能的最佳選擇 在當今互聯網高速發展的時代,各種網絡攻擊層出不窮,特別是針對網站的DDoS攻擊威脅,幾乎每個行業都可能成為目標。為了確保網站的安全性與穩定性,越來越多的企業開始關注Edge …

通信技術以及5G和AI保障電網安全與網絡安全

摘 要:電網安全是電力的基礎,隨著智能電網的快速發展,越來越多的ICT信息通信技術被應用到電力網絡。本文分析了歷史上一些重大電網安全與網絡安全事故,介紹了電網安全與網絡安全、通信技術與電網安全的關系以及相應的電網安全標準…

梯度(Gradient)和 雅各比矩陣(Jacobian Matrix)的區別和聯系:中英雙語

雅各比矩陣與梯度:區別與聯系 在數學與機器學習中,梯度(Gradient) 和 雅各比矩陣(Jacobian Matrix) 是兩個核心概念。雖然它們都描述了函數的變化率,但應用場景和具體形式有所不同。本文將通過…

時間序列預測論文閱讀和相關代碼庫

時間序列預測論文閱讀和相關代碼庫列表 MLP-based的時間序列預測資料DLinearUnetTSFPDMLPLightTS 代碼庫以及論文庫:Time-Series-LibraryUnetTSFLightTS MLP-based的時間序列預測資料 我會定期把我的所有時間序列預測論文有關的資料鏈接全部同步到這個文章中&#…

引言和相關工作的區別

引言和相關工作的區別 引言 目的與重點 引言主要是為了引出研究的主題,向讀者介紹為什么這個研究問題是重要且值得關注的。它通常從更廣泛的背景出發,闡述研究領域的現狀、面臨的問題或挑戰,然后逐漸聚焦到論文要解決的具體問題上。例如,在這篇關于聯邦學習數據交易方案的…

GitLab分支管理策略和最佳實踐

分支管理是 Git 和 GitLab 中非常重要的部分,合理的分支管理可以幫助團隊更高效地協作和開發。以下是一些細化的分支管理策略和最佳實踐: 1. 分支命名規范 ? 主分支:通常命名為 main 或 master,用于存放穩定版本的代碼。 ? …

批量提取zotero的論文構建知識庫做問答的大模型(可選)——含轉存PDF-分割統計PDF等

文章目錄 提取zotero的PDF上傳到AI平臺保留文件名代碼分成20個PDF視頻講解 提取zotero的PDF 右鍵查看目錄 發現目錄為 C:\Users\89735\Zotero\storage 寫代碼: 掃描路徑‘C:\Users\89735\Zotero\storage’下面的所有PDF文件,全部復制一份匯總到"C:\Users\89735\Downl…

LabVIEW實現NB-IoT通信

目錄 1、NB-IoT通信原理 2、硬件環境部署 3、程序架構 4、前面板設計 5、程序框圖設計 6、測試驗證 本專欄以LabVIEW為開發平臺,講解物聯網通信組網原理與開發方法,覆蓋RS232、TCP、MQTT、藍牙、Wi-Fi、NB-IoT等協議。 結合實際案例,展示如何利用LabVIEW和常用模塊實現物聯網…

面試題整理9----談談對k8s的理解2

面試題整理9----談談對k8s的理解2 1. Service 資源1.1 ServiceClusterIPNodePortLoadBalancerIngressExternalName 1.2 Endpoints1.3 Ingress1.4 EndpointSlice1.5 IngressClass 2. 配置和存儲資源2.1 ConfigMap2.2 Secret2.3 PersistentVolume2.4 PersistentVolumeClaim2.5 St…

精準采集整車信號:風丘混合動力汽車工況測試

一 背景 混合動力汽車是介于純電動汽車與燃油汽車兩者之間的一種新能源汽車。它既包含純電動汽車無污染、啟動快的優勢,又擁有燃油車續航便捷、不受電池容量限制的特點。在當前環境下,混合動力汽車比純電動汽車更符合目前的市場需求。 然而&#xff0c…

帶標題和不帶標題的內部表

什么是工作區? 什么是工作區?簡單來說,工作區是單行數據。它們應具有與任何內部表相同的格式。它用于一次處理一行內部表中的數據。 內表和工作區的區別 ? 一圖勝千言 內表的類型 有兩種類型的內表: 帶 Header 行…

【圖像分類實用腳本】數據可視化以及高數量類別截斷

圖像分類時,如果某個類別或者某些類別的數量遠大于其他類別的話,模型在計算的時候,更傾向于擬合數量更多的類別;因此,觀察類別數量以及對數據量多的類別進行截斷是很有必要的。 1.準備數據 數據的格式為圖像分類數據集…

【Leetcode 每日一題】2545. 根據第 K 場考試的分數排序

問題背景 班里有 m m m 位學生,共計劃組織 n n n 場考試。給你一個下標從 0 0 0 開始、大小為 m n m \times n mn 的整數矩陣 s c o r e score score,其中每一行對應一位學生,而 s c o r e [ i ] [ j ] score[i][j] score[i][j] 表示…

React系列(八)——React進階知識點拓展

前言 在之前的學習中,我們已經知道了React組件的定義和使用,路由配置,組件通信等其他方法的React知識點,那么本篇文章將針對React的一些進階知識點以及React16.8之后的一些新特性進行講解。希望對各位有所幫助。 一、setState &am…

PCIe_Host驅動分析_地址映射

往期內容 本文章相關專欄往期內容,PCI/PCIe子系統專欄: 嵌入式系統的內存訪問和總線通信機制解析、PCI/PCIe引入 深入解析非橋PCI設備的訪問和配置方法 PCI橋設備的訪問方法、軟件角度講解PCIe設備的硬件結構 深入解析PCIe設備事務層與配置過程 PCIe的三…

【閱讀記錄-章節6】Build a Large Language Model (From Scratch)

文章目錄 6. Fine-tuning for classification6.1 Different categories of fine-tuning6.2 Preparing the dataset第一步:下載并解壓數據集第二步:檢查類別標簽分布第三步:創建平衡數據集第四步:數據集拆分 6.3 Creating data loa…

ip_output函數

ip_output函數是Linux內核(特別是網絡子系統)中用于發送IPv4數據包的核心函數。以下是一個示例實現,并附上詳細的中文講解: int ip_output(struct net *net, struct sock *sk, struct sk_buff *skb) {struct iphdr *iph; /* 構建IP頭部 */iph = ip_hdr(skb);/* 設置服務…

梳理你的思路(從OOP到架構設計)_簡介設計模式

目錄 1、 模式(Pattern) 是較大的結構?編輯 2、 結構形式愈大 通用性愈小?編輯 3、 從EIT造形 組合出設計模式 1、 模式(Pattern) 是較大的結構 組合與創新 達芬奇說:簡單是複雜的終極形式 (Simplicity is the ultimate form of sophistication) —Leonardo d…