【大模型】微調實戰—使用 ORPO 微調 Llama 3

ORPO 是一種新穎微調(fine-tuning)技術,它將傳統的監督微調(supervised fine-tuning)和偏好對齊(preference alignment)階段合并為一個過程。這減少了訓練所需的計算資源和時間。此外,實證結果表明,ORPO 在各種模型規模和基準測試(benchmarks)上優于其他對齊方法。
在本文中,我們將使用 ORPO 和 TRL 庫對新的 Llama 3 8B 模型進行微調。

ORPO

指令微調(instruction tuning)和偏好對齊(preference alignment)是使LLM適應特定任務的基本技術。傳統上,這涉及一個多階段的過程:1/ 在指令上進行監督微調(Supervised Fine-Tuning, SFT),以使模型適應目標領域,然后 2/ 使用偏好對齊方法,如基于人類反饋的強化學習(Reinforcement Learning with Human Feedback, RLHF)或直接偏好優化(Direct Preference Optimization, DPO),以增加生成首選響應而非被拒絕響應的可能性。
在這里插入圖片描述

然而,研究人員發現了這種方法的局限性。雖然 SFT 有效地使模型適應所需的領域,但它無意中增加了在首選答案的同時生成不需要的答案的可能性。這就是為什么偏好調整階段對于擴大首選輸出和拒絕輸出的可能性之間的差距是必要的。
ORPO 由 Hong 和 Lee (2024) 提出,通過將指令調整和偏好對齊結合到一個單一的整體訓練過程中,為這個問題提供了一個優雅的解決方案。 ORPO 修改了標準語言建模目標,將負對數似然損失與優勢比 (OR) 項相結合。這種 OR 損失對被拒絕的響應進行弱懲罰,同時對首選響應進行強烈獎勵,從而使模型能夠同時學習目標任務并與人類偏好保持一致。
在這里插入圖片描述
ORPO 已在主要微調庫中實現,如 TRL、Axolotl 和 LLaMA-Factory。在下一節中,我們將了解如何與 TRL 一起使用。

使用 ORPO 微調 Llama 3

Llama 3 是Meta開發的最新大型語言模型(LLM)家族。該模型在一個包含15萬億個標記的數據集上進行了訓練(相比之下,Llama 2 的訓練數據集為2萬億個標記)。目前已經發布了兩種模型尺寸:一個是擁有70B參數的模型,另一個是較小的8B參數模型。70B參數的模型已經展示了令人印象深刻的性能,在MMLU基準測試中得分為82,在HumanEval基準測試中得分為81.7。
Llama 3 模型還將上下文長度增加到了8,192個標記(相比之下,Llama 2 為4,096個標記),并且有可能通過RoPE擴展到32k。此外,這些模型使用了一種新的分詞器,具有128K標記的詞匯量,從而減少了編碼文本所需的標記數量15%。這種詞匯量的增加也解釋了參數從70億增加到80億。
ORPO 需要一個偏好數據集,包括提示、選擇的答案和拒絕的答案。在此示例中,我們將使用 mlabonne/orpo-dpo-mix-40k ,它是以下高質量 DPO 數據集的組合:

  • argilla/distilabel-capybara-dpo-7k-binarized: highly scored chosen answers >=5 (2,882 samples)
  • argilla/distilabel-intel-orca-dpo-pairs: highly scored chosen answers>=9, not in GSM8K (2,299 samples)
  • argilla/ultrafeedback-binarized-preferences-cleaned: highly scoredchosen answers >=5 (22,799 samples)
  • argilla/distilabel-math-preference-dpo: highly scored chosen answers>=9 (2,181 samples)
  • unalignment/toxic-dpo-v0.2 (541 samples)
  • M4-ai/prm_dpo_pairs_cleaned (7,958 samples)
  • jondurbin/truthy-dpo-v0.1 (1,016 samples)

首先安裝所需的庫:

pip install -U transformers datasets accelerate peft trl bitsandbytes wandb

安裝完成后,我們可以導入必要的庫并登錄W&B(可選)

import gc
import osimport torch
import wandb
from datasets import load_dataset
# from google.colab import userdata
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (AutoModelForCausalLM,AutoTokenizer,BitsAndBytesConfig,TrainingArguments,pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format# wb_token = userdata.get('wandb')
# wandb.login(key=wb_token)

如果您有最新的 GPU,還應該能夠使用 Flash Attention 庫將默認的 eager Attention 實現替換為更高效的實現。

if torch.cuda.get_device_capability()[0] >= 8:#!pip install -qqq flash-attnattn_implementation = "flash_attention_2"torch_dtype = torch.bfloat16
else:attn_implementation = "eager"torch_dtype = torch.float16

接下來,我們將借助bitsandbytes 以 4 位精度加載 Llama 3 8B 模型。然后,我們使用 QLoRA 的 PEFT 設置 LoRA 配置。我還使用方便的 setup_chat_format() 函數來修改模型和標記生成器以支持 ChatML。它會自動應用此聊天模板,添加特殊標記,并調整模型嵌入層的大小以匹配新的詞匯表大小。
請注意,您需要提交訪問 meta-llama/Meta-Llama-3-8B 的請求并登錄您的 Hugging Face 帳戶。或者,您可以加載模型的非門控副本,例如 NousResearch/Meta–Llama-3-8B。(我選擇手動從NousResearch/Meta–Llama-3-8B下載)

# Model
base_model = "meta-llama/Meta-Llama-3-8B"
new_model = "OrpoLlama-3-8B"# QLoRA config
bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch_dtype,bnb_4bit_use_double_quant=True,
)# LoRA config
peft_config = LoraConfig(r=16,lora_alpha=32,lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)# Load model
model = AutoModelForCausalLM.from_pretrained(base_model,quantization_config=bnb_config,device_map="auto",attn_implementation=attn_implementation
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

現在模型已準備好進行訓練,我們可以處理數據集了。我們加載 mlabonne/orpo-dpo-mix-40k 并使用 apply_chat_template() 函數將“chosen”和“rejected”列轉換為 ChatML 格式。請注意,我僅使用 1,00 個樣本,而不是整個數據集,因為運行時間太長。(我選擇手動下載)

dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(100))def format_chat_template(row):row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)return rowdataset = dataset.map(format_chat_template,num_proc= os.cpu_count(),
)
dataset = dataset.train_test_split(test_size=0.01)

首先,我們需要設置一些超參數: * learning_rate :與傳統的 SFT 甚至 DPO 相比,ORPO 使用非常低的學習率。 8e-6這個值來自原始論文,大致對應于SFT學習率1e-5和DPO學習率5e-6。我建議將其增加到 1e-6 左右以進行真正的微調。 * beta :即論文中的 𝜆 參數,默認值為0.1。原始論文的附錄顯示了如何通過消融研究選擇它。 * 其他參數,如 max_length 和批量大小設置為使用盡可能多的可用 VRAM(此配置中約為 20 GB)。理想情況下,我們會訓練模型 3-5 個 epoch,但這里我們堅持使用 1 個 epoch。
最后,我們可以使用 ORPOTrainer 來訓練模型,它充當包裝器。

orpo_args = ORPOConfig(learning_rate=8e-6,beta=0.1,lr_scheduler_type="linear",max_length=1024,max_prompt_length=512,per_device_train_batch_size=2,per_device_eval_batch_size=2,gradient_accumulation_steps=4,optim="paged_adamw_8bit",num_train_epochs=1,evaluation_strategy="steps",eval_steps=0.2,logging_steps=1,warmup_steps=10,report_to="wandb",output_dir="./results/",
)trainer = ORPOTrainer(model=model,args=orpo_args,train_dataset=dataset["train"],eval_dataset=dataset["test"],peft_config=peft_config,tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(new_model)

中間需要選擇是否使用W&B,不會使用,我選擇不使用
在這里插入圖片描述
完成了 Llama 3 的快速微調:mlabonne/OrpoLlama-3-8B
在這里插入圖片描述

生成目錄:
在這里插入圖片描述

合并完整模型到本地:

# Flush memory
del trainer, model
gc.collect()
torch.cuda.empty_cache()# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model,low_cpu_mem_usage=True,return_dict=True,torch_dtype=torch.float16,device_map="auto",
)
model, tokenizer = setup_chat_format(model, tokenizer)# Merge adapter with base model
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()# Save the merged model and tokenizer to local directory
local_save_directory = "new_model"
model.save_pretrained(local_save_directory)
tokenizer.save_pretrained(local_save_directory)

得到和初始模型一樣結構的微調模型;
在這里插入圖片描述
完整教程:https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html
本文使用代碼對原代碼改了一部分。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/43814.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/43814.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/43814.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

使用微pe裝系統

本文僅作為記錄,不作為教程。 今天心血來潮想下點游戲玩玩,一看之前分的200gc盤已經紅了,再加上大學之后這個筆記本已經用得很少了,于是打算重裝電腦。 參考: 微PE輔助安裝_嗶哩嗶哩_bilibil… 1.下載微pe和win10系統到U盤 我這…

Xilinx zc706 USB電路解析

作者 QQ群:852283276 微信:arm80x86 微信公眾號:青兒創客基地 B站:主頁 https://space.bilibili.com/208826118 參考 USB OTG檢測原理 USB3320 USB_ID為低電平時候,為host模式,USB_ID為懸空(高…

python-23-零基礎自學python open()和replace()函數運用

學習內容:《python編程:從入門到實踐》第二版練習10-2 知識點: 打開文件,replace()替換文件內容,open(), 練習內容: 練習10-2:C語言學習筆記 可使用方法replace()將字符串中的特定單詞都替換為另一個單…

云計算環境下的等級保護測評

概述 云計算環境下的等級保護測評是一個涵蓋多個層面的綜合性評估活動,它不僅包括了傳統的信息系統安全等級保護測評內容,還需要考慮到云計算特有的安全特性和挑戰。隨著云計算技術的迅猛發展和廣泛應用,其在政務、金融、教育等行業中的角色日…

代碼隨想錄訓練營第三十一天 56合并區間 738單調遞增的數字

第一題: 原題鏈接:56. 合并區間 - 力扣(LeetCode) 思路:首先還是排序; 然后定義一個二維數組存放結果,先將第一個元素存放到結果數組中,然后從第一個元素開始遍歷整個數組。 當前…

kafka系列之offset超強總結及消費后不提交offset情況的分析總結

概述 每當我們調用Kafka的poll()方法或者使用Spring的KafkaListener(其實底層也是poll()方法)注解消費Kafka消息時,它都會返回之前被寫入Kafka的記錄,即我們組中的消費者還沒有讀過的記錄。 這意味著我們有一種方法可以跟蹤該組消費者讀取過的記錄。 如前…

6.824/6.5840 的Debugging by Pretty Printing配置

TA的原文在:Debugging by Pretty Printing (josejg.com) 為了在WSL2中配置好打印運行日志,我可是忙活了一下午。可惡的log配置 首先是安裝rich庫Textualize/rich: Rich is a Python library for rich text and beautiful formatting in the terminal. …

用于視頻生成的擴散模型

學習自https://lilianweng.github.io/posts/2024-04-12-diffusion-video/ 文章目錄 3D UNet和DiTVDMImagen VideoSora 調整圖像模型生成視頻Make-A-Video(對視頻數據微調)Tune-A-VideoGen-1視頻 LDMSVD穩定視頻擴散 免訓練Text2Video-ZeroControlVideo 參…

需求分析|泳道圖 ProcessOn教學

文章目錄 1.為什么使用泳道圖2.具體例子一、如何繪制確定好泳道中樞的角色在中央基于事實來繪制過程不要糾結美觀先畫主干處理流程再畫分支處理流程一個圖表達不完,切分子流程過程數不超25 ,A4紙的幅面處理過程過程用動詞短語最后美化并加上序號酌情加上…

leetcode hot 100 刷題記錄

題目300:最長遞增子序列(NO) 解題思路:動態規劃,就是dp[i]的運用,這里dp[i]表示第i個元素為結尾的最長子序列。 給你一個整數數組 nums ,找到其中最長嚴格遞增子序列的長度。 子序列 是由數組…

后端——全局異常處理

一、老辦法try-catch 當我們執行一些錯誤操作導致程序報錯時,程序會捕捉到異常報錯,這個異常會存在一個Exception對象里 那我們在spring boot工程開發時,當我們執行一個sql查詢時報錯了,那就會從最底層的Mapper層捕捉到Exceptio…

Android應用程序調試Logcat的使用

Android的程序調試主要使用Logcat進行,本節主要介紹Logcat的使用。 開啟調試模式 使用Android Studio進行程序調試,首先需要連接虛擬Android設備或真實Android設備,設備上需要啟用調試功能。 虛擬Android設備默認情況下會啟用調試功能。對…

C++ 入門03:函數與作用域

往期回顧: C 入門01:初識 C-CSDN博客C 入門02:控制結構和循環-CSDN博客 一、前言 在前面的文章學習中,我們了解了C語言的基礎,包括如何定義變量來存儲數據,以及如何利用輸入輸出流實現程序與用戶之間的無縫…

華為機考真題 -- 找朋友

題目描述: 在學校中,N 個小朋友站成一隊, 第 i 個小朋友的身高為 height[i],第 i 個小朋友可以看到的第一個比自己身高更高的小朋友 j,那么 j 是 i 的好朋友(要求 j >i)。請重新生成一個列表,對應位置的輸出是每個小朋友的好朋友位置,如果沒有看到好朋友,請在該位置…

微軟清華提出全新預訓練范式,指令預訓練讓8B模型實力暴漲!實力碾壓70B模型

現在的大模型訓練通常會包括兩個階段: 一是無監督的預訓練,即通過因果語言建模預測下一個token生成的概率。該方法無需標注數據,這意味著可以利用大規模的數據學習到語言的通用特征和模式。 二是指令微調,即通過自然語言指令構建…

Python面試題:請解釋什么是鴨子類型(duck typing)?

鴨子類型(Duck Typing)是一種動態類型語言中的概念,它基于對象的行為(方法和屬性)而不是其實際類型進行判斷。這個概念源自詹姆斯惠特科姆賴利的諺語: “如果它走起來像鴨子,叫起來像鴨子&#…

通過高德地圖 JS API實現單擊鼠標進行標注

效果圖: 核心代碼: <template><a-modal title="選擇地圖所在位置" :width="width" :visible="visible" @ok="handleOk" @cancel="handleCancel" cancelText="關閉"><div class="location-…

場外期權有交割日嗎?場外期權應該怎么交割?

今天帶你了解場外期權有交割日嗎&#xff1f;場外期權應該怎么交割&#xff1f;場外個股期權是一種非標準化的金融衍生品&#xff0c;它允許投資者在未來某一特定日期以特定價格買入或賣出某一特定股票。 交割日就是買賣雙方進行交割的日期,期權合約具有到期日,到期日的后一天…

WEB安全-文件上傳漏洞

1 需求 2 接口 3 MIME類型 在Web開發中&#xff0c;MIME&#xff08;Multipurpose Internet Mail Extensions&#xff09;類型用于標識和表示文檔的格式。這些類型在HTTP請求和響應頭中扮演著重要的角色&#xff0c;告訴瀏覽器如何解釋和處理接收到的資源12。 以下是一些Web開發…

ChatGPT:Java Stream 的疑問

ChatGPT&#xff1a;Java Stream 的疑問 解釋一下 List<SupplierVm> collect tSupplierPage.getRecords().stream().map(item ->{SupplierVm supplierVm new SupplierVm();BeanUtils.copyProperties(item, supplierVm);return supplierVm;}).collect(Collectors.to…