使用 LLaMA 3 8B 微調一個 Reward Model:從入門到實踐

本文將介紹如何基于 Meta 的 LLaMA 3 8B 模型構建并微調一個 Reward Model,它是構建 RLHF(基于人類反饋的強化學習)系統中的關鍵一環。我們將使用 Hugging Face 的 transformerstrlpeft 等庫,通過參數高效微調(LoRA)實現高質量 Reward Model 的訓練。

什么是 Reward Model?

Reward Model(RM)是 RLHF 流程中的評分器,它學習人類偏好:在多個候選回答中判斷哪個更符合用戶意圖。訓練目標是使模型給出更高 reward 分數的輸出更符合人類偏好,常用于后續的強化學習微調如 PPO、DPO 等。

技術選型

  • 模型基座LLaMA 3 8B(你需要有模型訪問權限)

  • 微調方法LoRA(Parameter-Efficient Fine-Tuning)

  • 訓練庫:trl (Transformers Reinforcement Learning)

  • 數據格式:偏好比較數據(prompt, chosen, rejected)

數據格式示例

Reward Model 使用的是 pairwise preference 數據,基本格式如下:

{"prompt": "什么是人工智能?","chosen": "人工智能是讓機器具備模擬人類智能的能力,例如學習、推理、感知等。","rejected": "人工智能就是讓機器變得更厲害。"
}
  • prompt 是輸入問題

  • chosen 是較優回答

  • rejected 是較差回答

我們訓練模型區分出“好回答”和“不好回答”。

安裝依賴

pip install transformers peft trl accelerate datasets bitsandbytes

加載 LLaMA 3 模型

我們使用 Hugging Face 的 transformers 加載 LLaMA 3,并通過 LoRA 應用微調。

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_modelmodel_name = "meta-llama/Meta-Llama-3-8B"tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # 處理 paddingmodel = AutoModelForCausalLM.from_pretrained(model_name,load_in_8bit=True,          # 節省顯存device_map="auto"
)# 應用 LoRA
lora_config = LoraConfig(r=8,lora_alpha=16,lora_dropout=0.05,bias="none",task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)

準備數據集

我們使用本地 JSON 文件作為訓練數據,并轉換為 Hugging Face Dataset 格式。

from datasets import Dataset
import jsonwith open("data/reward_data.json", "r", encoding="utf-8") as f:raw_data = json.load(f)dataset = Dataset.from_list(raw_data)

使用 RewardTrainer 訓練模型

我們使用 trl 中的 RewardTrainer,它自動處理 pairwise loss(log-sigmoid ranking loss),非常適合訓練 Reward Model。

from trl import RewardTrainer, RewardConfigtraining_args = RewardConfig(output_dir="./output/rm-llama3",per_device_train_batch_size=2,gradient_accumulation_steps=4,learning_rate=1e-5,max_length=1024,num_train_epochs=3,logging_steps=10,save_strategy="epoch",remove_unused_columns=False,bf16=True,  # 或根據硬件選擇 fp16/bf16
)trainer = RewardTrainer(model=model,tokenizer=tokenizer,train_dataset=dataset,args=training_args,
)trainer.train()

保存模型

trainer.save_model("./output/rm-llama3")
tokenizer.save_pretrained("./output/rm-llama3")

保存后的模型可以直接用于 PPO、DPO 等強化學習階段,作為 reward function 評估輸出質量。

獎勵評分邏輯(原理簡述)

雖然你加載的是普通的語言模型(AutoModelForCausalLM),但 RewardTrainer 會這樣做:

  1. 輸入 prompt + chosenprompt + rejected 兩個序列

  2. 使用語言模型計算每個序列的 log-likelihood(對數似然)

  3. 總結每個序列的 log-prob 得分作為 reward 分數

  4. log(sigmoid(reward_chosen - reward_rejected)) 作為 loss,更新參數

這個過程實現了 pairwise preference learning,而你無需自定義 loss 函數。

?非lora 的方式訓練的reward 模型。

如何訓練一個 Reward Model:RLHF 的核心組件詳解_reward model訓練-CSDN博客

參考資料

https://github.com/huggingface/trl

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/90625.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/90625.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/90625.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

matrix-breakout-2-morpheus靶場攻略

靶場使用將壓縮包解壓到一個文件夾中,用虛擬機應用新建虛擬機,掃描虛擬機,掃描那個文件夾,就可以把虛擬機掃出來了,然后啟動虛擬機這時候靶場啟動后,咱們現在要找到這個靶場。靶場是網頁形式的,…

MySQL 復制表

MySQL 復制表 概述 在數據庫管理中,復制表是一項常用的操作。它允許數據庫管理員將一個表中的數據復制到另一個表中,無論是同一個數據庫還是不同的數據庫。MySQL數據庫提供了多種方法來復制表,本文將詳細介紹MySQL復制表的過程、方法及其應用…

『哈哥贈書 - 55期』-『碼農職場:IT人求職就業手冊』

文章目錄?? 碼農職場:IT人求職就業手冊?? 本書簡介?? 作者簡介?? 編輯推薦這是一本專為廣大IT行業求職者量身定制的指南,提供了從職前準備到成功就業的全方位指導,涵蓋了職業目標規劃、自我技能評估、求職策略、簡歷準備以及職場心理…

單片機學習課程

單片機學習課程 課程介紹 單片機技術作為現代工業自動化、電子電氣、通信及物聯網等領域的主流技術,早已深度融入我們生活與生產的各個角落。從常見家電到自動化公共設施,都離不開單片機的支持。同時,它也是學習 ARM 嵌入式系統、FPGA 設計等…

【AcWing 143題解】最大異或對

AcWing 143. 最大異或對 【題目描述】 在查看解析之前,先給自己一點時間思考哦! 【題解】 本題要求給定一個整數序列,找出其中任意兩個數進行異或運算后,結果的最大值是多少。由于數據規模較大,我們不能簡單地通過兩…

SQLAlchemy 2.0簡單使用

記錄一下SQLAlchemy 2.0連接mysql數據庫的方法及簡單使用 環境及依賴 Python:3.8 mysql:8.3 Flask:3.0.3 SQLAlchemy:2.0.37 PyMySQL:1.1.1使用步驟 1、創建引擎,鏈接到mysql engine create_engine(mysqlpymysql://{username}:{password}{ip}:3306/{database_name}…

如何創建或查看具有 repo 權限的 GitHub 個人訪問令牌(PAT)

要創建或查看具有 repo 權限的 GitHub 個人訪問令牌(PAT),請按照以下步驟操作: 一、生成具有 repo 權限的 PAT 登錄 GitHub 訪問 GitHub 官網,使用你的賬戶登錄。 進入開發者設置 點擊右上角頭像,選擇 Settings(設置) → 左側菜單中選擇 Developer settings(開發者設…

【AI時代速通QT】第五節:Qt Creator如何引入第三方庫,以OpenCV為例

目錄 引言 一、第一步:萬事開頭難 - 準備工作 1.1 獲取并“安裝”OpenCV 1.2 創建一個新的Qt項目 1.3 建立專業的項目目錄結構 二、第二步:核心操作 - 配置.pro文件 2.1 方式一:圖形化向導(適合初次體驗) 2.2 …

使用Clion開發STM32(Dap調試)

使用Clion開發STM32環境配置ST-Link無法下載OpenOCDST-Link調試Dap-Link調試Debug配置查看寄存器值之前寫了一篇文章關于如何用VSCode配合EIDE插件開發STM32 最近研究了如何使用Clion開發STM32 環境配置 使用Clion開發STM32需要用到4個工具:Clion、STM32CubeMX、…

人工智能-python-OpenCV 中 `release()` 和 `destroy()` 的區別

文章目錄OpenCV 中 release() 和 destroy() 的區別1. release()常見使用場景:代碼示例:作用:2. destroy()常見使用場景:代碼示例:作用:3. 總結:4. 何時使用小結:OpenCV 中 release()…

[RPA] 日期時間練習案例

案例1根據日期拆分表格根據表格中不同日期,創建多個對應日期名稱的Sheet頁(名稱格式為"yyyy-mm-dd"),并將同一日期的訂單拷貝至對應Sheet頁日期時間練習題1.xlsx流程搭建:實現效果:

2025.7.27文獻閱讀-基于深度神經網絡的半變異函數在高程數據普通克里金插值中的應用

2025.7.27周報一、文獻閱讀題目信息摘要創新點實驗一、半變異函數擬合二、普通克里金插值三、結果對比分析四、實驗結果結論不足以及展望一、文獻閱讀 題目信息 題目: Application of a semivariogram based on a deep neural network to Ordinary Kriging interp…

用unity開發教學輔助軟件---幼兒繪本英語拼讀

記錄完整項目的制作,借鑒了大佬被代碼折磨的狗子 “unity創建《找不同》游戲 圖片編輯器”一文。 (建議通過目錄閱讀本文哦~) 項目演示: 幼兒英語教輔幼兒英語繪本教學游戲整體架構 游戲開發中設計的整體框架 游戲的總體功能框架…

《Java 程序設計》第 5 章 - 數組詳解

引言在 Java 編程中,數組是一種基礎且重要的數據結構,它允許我們將多個相同類型的元素存儲在一個連續的內存空間中,通過索引快速訪問。掌握數組的使用是學習 Java 集合框架、算法等高級知識的基礎。本章將從數組的創建、使用開始,…

基于Spring Boot的可盈保險合同管理系統的設計與實現(源碼+論文)

一、相關技術 技術/工具描述SSM框架在JavaWeb開發中,SSM框架(Spring Spring MVC MyBatis)是流行的選擇。它既沒有SSH框架的臃腫,也沒有SpringMVC的簡化,屬于中間級別,更靈活且易于編寫和理解。MyBatis框…

??XSLT:XML轉換的“魔法棒”?

大家好!今天我們來聊聊 ??XSLT??(Extensible Stylesheet Language Transformations),一種用于轉換和呈現XML文檔的神奇工具。如果你曾需要將一堆枯燥的XML數據變成精美的HTML網頁、PDF報告,或其他XML格式&#xff…

面試實戰,問題十,如何保證系統在超過設計訪問量時仍能正常運行,怎么回答

如何保證系統在超過設計訪問量時仍能正常運行 在Java面試中,當被問及如何保證系統在訪問量激增(例如從100萬用戶增長到200萬)時仍能穩定運行,這是一個考察高并發、可擴展性和容錯能力的關鍵問題。核心在于通過架構設計、性能優化和…

DMDSC安裝部署教程

一、環境準備 虛擬機準備,添加共享磁盤 (1)共享存儲規劃 裸設備名 容量 用途 /dev/sdb 10 G /dev/asmdata0(數據磁盤) /dev/sdc 5 G /dev/asmdcr(DCR 磁盤) /dev/sdd 5 G /dev/asm…

半導體 CIM(計算機集成制造)系統

半導體CIM(Computer Integrated Manufacturing,計算機集成制造)系統是半導體制造的“神經中樞”,通過整合硬件設備、軟件系統和數據流轉,實現從訂單到成品的全流程自動化、信息化和智能化管理。其工作流程高度貼合半導…

AI是否會終結IT職業?深度剖析IT行業的“涌現”與重構

引言:一場不可回避的技術審判在ChatGPT、Copilot、Claude、Sora 等AI技術密集爆發的今天,IT行業首當其沖地感受到這股浪潮帶來的“智力替代壓力”。尤其是以開發、測試、運維、分析為主的崗位,逐漸被AI所“滲透”。于是,問題擺在每…