RLHF獎勵模型的訓練

由于 RLHF 的訓練過程中需要依賴大量的人類偏好數據進行學習,因此很難在訓練過程中要求人類標注者實時提供偏好反饋。為此,我們需要訓練一個模型來替代人類在 RLHF 訓練過程中實時提供反饋,這個模型被稱為獎勵模型

🔸一、 目標函數公式解釋

公式如下:

L = ? E ( x , y + , y ? ) ~ D [ log ? σ ( r θ ( x , y + ) ? r θ ( x , y ? ) ) ] ? β E ( x , y + ) ~ D [ ∑ t = 1 T log ? p ( y t + ∣ x , y < t + ) ] L = -\mathbb{E}_{(x, y^+, y^-) \sim D} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right] - \beta \mathbb{E}_{(x, y^+)\sim D} \left[ \sum_{t=1}^{T} \log p(y^+_t \mid x, y^+_{<t}) \right] L=?E(x,y+,y?)D?[logσ(rθ?(x,y+)?rθ?(x,y?))]?βE(x,y+)D?[t=1T?logp(yt+?x,y<t+?)]

含義拆解:

  • x: 輸入(如問題或提示語)
  • y+: 正例響應(由人類標注或偏好選擇的答案)
  • y-: 負例響應(不好的答案)
  • r_θ(x, y): 獎勵模型對 (x, y) 的打分(通常是最后一個 token 的輸出經過 reward head 得到)
  • σ: Sigmoid 函數
  • β: 權重超參,控制模仿學習(第二項)對總損失的影響程度

公式兩部分含義:

  1. 對比損失(ranking loss)

    $$

    • \log \sigma(r(x, y^+) - r(x, y^-))
      $$
    • 目標是使 正例得分 > 負例得分
    • r(x, y+) ? r(x, y-) 時,sigmoid接近1,log接近0 → 損失小,說明模型學得好
  2. 模仿學習損失(語言模型 loss)

    $$

    • \sum_{t=1}^{T} \log p(y^+t \mid x, y^+{<t})
      $$
    • 即:語言模型在給定輸入 x 和前綴 y^+_{<t} 的條件下,預測下一個 token 的交叉熵損失
    • 起正則作用,防止獎勵模型過度擬合打分而喪失語言生成能力

🔸二、代碼結構分析

基于 LLaMA 的獎勵模型實現詳解(逐行解讀 + PyTorch 源碼分析)

📦 模塊導入

1  import torch
2  import torch.nn as nn
3  import torch.nn.functional as F
4
5  from transformers import LlamaForCausalLM
  • torch:PyTorch 核心包
  • nn:用于定義神經網絡模塊(如 Linear)
  • F:包含函數式接口(如 loss 函數)
  • LlamaForCausalLM:來自 Transformers 的 LLaMA 語言模型基類,支持自回歸文本生成

🧠 模型定義:獎勵模型類

7  class LlamaRewardModel(LlamaForCausalLM):
8      def __init__(self, config):
9          super().__init__(config)
10
11         # 初始化線性變換層,將隱狀態映射為標量,用于輸出最終獎勵
12         self.reward_head = nn.Linear(config.hidden_size, 1, bias=False)
  • LlamaRewardModel 繼承自 HuggingFace 的 LlamaForCausalLM
  • 增加了一個 reward_head 線性層,用于將模型輸出(hidden state)映射為 獎勵值(scalar)

🧾 正例/負例打分函數 _forward_rmloss

14 def _forward_rmloss(self, input_ids, attention_mask, **kargs):
18     output = self.model.forward(
19         input_ids=input_ids,
20         attention_mask=attention_mask,
21         return_dict=True,
22         use_cache=False
23     )
25     logits = self.reward_head(output.last_hidden_state).squeeze(-1)
26     return logits
  • 輸入:拼接后的 [x, y] 序列
  • self.model.forward(...):獲得 LLaMA 模型輸出(hidden states)
  • self.reward_head(...):只對最后一層 hidden state 應用線性映射,輸出獎勵值
  • squeeze(-1):去除最后一維 [batch, 1] -> [batch]

squeeze(-1) 的作用是去掉張量的最后一個維度,前提是該維度的值是 1。
假設 logits 是一個 [batch_size, 1] 的張量:
logits = tensor([[0.73], [0.24], [0.91]]) # shape: [3, 1]
執行:
logits = logits.squeeze(-1)
結果為:
tensor([0.73, 0.24, 0.91]) # shape: [3]

?? 模仿學習損失函數 _forward_lmloss

29 def _forward_lmloss(self, prompt_ids, lm_attn_mask, response_ids):
35     outputs = self.model.forward(
36         input_ids=prompt_ids,
37         attention_mask=lm_attn_mask,
38         return_dict=True,
39         use_cache=False,
40     )
42     hidden_states = outputs.last_hidden_state
43     logits = self.lm_head(hidden_states)
44     loss_fct = nn.CrossEntropyLoss()
45     logits = logits.view(-1, self.config.vocab_size)
46     response_ids = response_ids.view(-1)
47     loss = loss_fct(logits, response_ids)
48     return loss
  • prompt_ids[x, y?] 拼接后的 token ID
  • 輸出 logits:維度 [batch_size, seq_len, vocab_size]
  • 計算交叉熵損失:對所有位置預測的 token 與 response_ids 進行對比

🚀 前向傳播函數:組合損失計算

50 def forward(self, sent1_idx, attention_mask_1, sent2_idx,attention_mask_2, labels, prompt_ids, lm_attn_mask, response_ids):

參數說明:

  • sent1_idx: [x, y?] 拼接輸入(正例)
  • sent2_idx: [x, y?] 拼接輸入(負例)
  • labels: 全 0 標簽,用于對比損失
  • prompt_ids: 與正例相關的 token(用于 LM Loss)
  • response_ids: 正例的 target token(用于 LM Loss)

計算對比損失(Reward Loss)

61 reward0 = self._forward_rmloss(sent1_idx, attention_mask_1)
66 reward1 = self._forward_rmloss(sent2_idx, attention_mask_2)
71 logits = reward0 - reward1
72 rm_loss = F.binary_cross_entropy_with_logits(logits,labels.to(logits.dtype), reduction="mean")
  • 分別計算 r(x, y?)r(x, y?)

  • 構造 logits = r? - r?

  • 用 Binary Cross Entropy Loss 計算 reward loss

    公式對應:
    ? log ? ( σ ( r ( x , y + ) ? r ( x , y ? ) ) ) -\log(\sigma(r(x, y?) - r(x, y?))) ?log(σ(r(x,y+)?r(x,y?)))


計算語言模型損失(Language Modeling Loss)

75 lm_loss = self._forward_lmloss(prompt_ids, lm_attn_mask, response_ids)
  • 與傳統語言模型訓練一致,使用 CrossEntropyLoss

返回總損失

78 loss = rm_loss + lm_loss
79 return loss
  • 二者直接加和(可選加權項 β,可自己加參數)
  • 模型即同時優化打分能力 + 文本生成能力(聯合學習)

🔸四、總結

項目描述
核心思想同時學習獎勵模型 r_θ 和保持生成流暢性
優勢1. 保留強化學習能力
2. 不失語義與流暢性
應用場景RLHF 的 reward 模型訓練階段,如 OpenAI 的 GPT 訓練流程中 Step 2: Train Reward Model
可調參數β 控制生成質量與偏好打分之間的權衡

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/85178.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/85178.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/85178.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

reverse_ssh 建立反向 SSH 連接指南 混淆AV [好東西喲]

目錄 &#x1f310; 工具簡介 ?? 前提條件 攻擊主機 (Linux) 目標主機 (Windows) &#x1f4cb; 詳細步驟 步驟 1&#xff1a;安裝 Go 環境 步驟 2&#xff1a;安裝必要依賴 步驟 3&#xff1a;下載并編譯 reverse_ssh 步驟 4&#xff1a;配置密鑰 步驟 5&#xff…

Ubuntu 下搭建ESP32 ESP-IDF開發環境,并在windows下用VSCode通過SSH登錄Ubuntu開發ESP32應用

Ubuntu 下搭建ESP32 ESP-IDF開發環境&#xff0c;網上操作指南很多&#xff0c;本來一直也沒有想過要寫這么一篇文章。因為我其實不太習慣在linux下開發應用&#xff0c;平時更習慣windows的軟件操作&#xff0c;只是因為windows下開發ESP32的應用編譯時太慢&#xff0c;讓人受…

Rust使用Cargo構建項目

文章目錄 你好&#xff0c;Cargo&#xff01;驗證Cargo安裝使用Cargo創建項目新建項目配置文件解析默認代碼結構 Cargo工作流常用命令速查表詳細使用說明1. 編譯項目2. 運行程序3.快速檢查4. 發布版本構建 Cargo的設計哲學約定優于配置工程化優勢 開發建議1. 新項目初始化?2. …

免費且好用的PDF水印添加工具

軟件介紹 琥珀掃描.zip下載鏈接&#xff1a;https://pan.quark.cn/s/3a8f432b29aa 今天要給大家推薦一款超實用的PDF添加水印工具&#xff0c;它能夠滿足用戶給PDF文件添加水印的需求&#xff0c;而且完全免費。 這款PDF添加水印的軟件有著簡潔的界面&#xff0c;操作簡便&a…

NW969NW978美光閃存顆粒NW980NW984

NW969NW978美光閃存顆粒NW980NW984 技術解析&#xff1a;NW969、NW978、NW980與NW984的架構創新 美光&#xff08;Micron&#xff09;的閃存顆粒系列&#xff0c;尤其是NW969、NW978、NW980和NW984&#xff0c;代表了存儲技術的前沿突破。這些產品均采用第九代3D TLC&#xf…

Mysql常用知識3:Kafka和數據庫優化

文章目錄 一、分布式消息系統&#xff08;Kafka相關問題5-10&#xff09;5. Kafka如何保證消息不丟失&#xff1f;6. 項目中Kafka具體怎么使用的&#xff1f;7. 消息異常未發送成功怎么解決&#xff1f;8. 重試具體怎么做的&#xff0c;循環嗎&#xff1f;9. 重試多次失敗怎么辦…

常見的RAG文檔解析輔助工具匯總及企業選型思考

以下當前比較知名的RAG的文檔解析輔助工具的開源項目匯總&#xff0c;包含核心功能、License信息及GitHub地址&#xff1a; 1. RAGFlow 核心功能&#xff1a;支持PDF/掃描件/CAD等23種格式解析&#xff0c;OCR準確率98%&#xff0c;知識圖譜融合&#xff0c;混合檢索&#xf…

基于Sqoop的MySQL-Hive全量/增量同步解決方案(支持多表批量處理

一、全量同步方案設計 1.1 基礎命令模板 sqoop import \ --connect jdbc:mysql://mysql_host:3306/db_name \ --username user \ --password pass \ --table source_table \ --hive-import \ --hive-table target_table \ --hive-overwrite \ # 覆蓋已有表 --num-mappers 8 …

前端學習(7)—— HTML + CSS實現博客系統頁面

目錄 一&#xff0c;效果展示 二&#xff0c;實現博客列表頁 2.1 實現導航欄 2.2 實現個人信息 2.3 實現博客列表 三&#xff0c;實現博客正文頁 3.2 復用 3.4 實現博客正文 四&#xff0c;實現博客登錄頁 4.1 版心 4.2 登錄框 五&#xff0c;實現博客編輯頁 5.1 …

【技能拾遺】——家庭寬帶單線復用布線與配置(移動2025版)

&#x1f4d6; 前言&#xff1a;在家庭網絡拓撲中&#xff0c;客廳到弱電箱只預埋了一根網線&#xff0c;由于已將廣電的有線電視取消并改用IPTV。現在需要解決在客廳布置路由器和觀看IPTV問題&#xff0c;這里就用到單線復用技術。 目錄 &#x1f552; 1. 拓撲規劃&#x1f55…

VTK|實現類似CloundCompare的測量功能

文章目錄 CloundCompare在點、線、面三種模式下的顯示內容? 圖1&#xff1a;點模式? 圖2&#xff1a;線模式? 圖3&#xff1a;面模式 增加控制菜單欄實現測量功能類如何調用項目git鏈接 CloundCompare在點、線、面三種模式下的顯示內容 點 線 面 三張圖展示了 CloudComp…

4000萬日訂單背后,餓了么再掀即時零售的“效率革命”

當即時零售轉向價值深耕&#xff0c;贏面就是綜合實力的強弱。 文&#xff5c;郭夢儀 編&#xff5c;王一粟 在硝煙彌漫的外賣行業“三國殺”中&#xff0c;餓了么與淘寶閃購的日訂單量竟然突破了4000萬單。 而距淘寶閃購正式上線&#xff0c;還不到一個月。 在大額福利優惠…

vedio.ontimeupdate()和video.onloadeddata()

video.onloadeddata &#xff08;&#xff09; video.onloadeddata 是 JavaScript 中用于監聽 HTML <video> 元素 「當前幀數據已加載」 的事件處理器。當視頻的第一幀畫面數據加載完成&#xff08;足以開始播放&#xff09;時&#xff0c;會觸發此事件。 1. 基本用法 …

Baklib內容中臺革新企業知識實踐

Baklib智能知識中樞構建 作為現代企業知識管理的核心架構&#xff0c;Baklib內容中臺通過整合多源異構數據形成智能化知識中樞&#xff0c;實現從信息采集到價值轉化的全鏈路管理。其底層采用跨平臺數據貫通技術&#xff0c;支持API接口與企業現有CRM、ERP系統無縫對接&#x…

用不太嚴謹的文字介紹遙測自跟蹤天線的基本原理

前兩天跟一個客戶見面的時候&#xff0c;客戶問我&#xff1a;遙測自跟蹤天線能夠跟蹤目標&#xff0c;是什么原理&#xff1f;不需要目標的位置&#xff0c;怎么做到自跟蹤的&#xff1f; 突然一瞬間&#xff0c;有點語塞。 難道要介紹天線、饋源、極化、左旋、右旋、和差網…

VS配置redis環境、redis簡單封裝

一、安裝redis數據庫 1.下載redis的壓縮包 wget https://download.redis.io/releases/redis-6.0.5.tar.g 2.解壓縮redis壓縮包&#xff0c;一般就在當前路徑 tar -zvxf redis-6.0.5.tar.gz -C /usr/local/redis 方便找我把它解壓縮在/usr/local/redis&#xff0c;如果沒有r…

C++23 已移除特性解析

文章目錄 引言C23 已移除特性介紹1. 垃圾收集的支持和基于可達性的泄漏檢測&#xff08;P2186R2&#xff09;背景與原理存在的問題移除的影響 2. 混合寬字符串字面量拼接非良構&#xff08;P2201R1&#xff09;寬字符串編碼概述混合拼接的問題示例分析移除的意義 3. 不可編碼寬…

Cloudflare

Cloudflare 是一個網絡基礎設施和網站安全服務提供商&#xff0c;它的主要作用是讓網站 更快、更安全、更可靠。簡單來說&#xff0c;它是一個“護盾 加速器”。 &#x1f9e9; Cloudflare 的主要功能&#xff1a; 1. &#x1f680; 加速網站訪問&#xff08;CDN&#xff09…

Spring Boot啟動慢?Redis緩存擊穿?Kafka消費堆積?——Java后端常見問題排查實戰

Spring Boot啟動慢&#xff1f;Redis緩存擊穿&#xff1f;Kafka消費堆積&#xff1f;——Java后端常見問題排查實戰 引言 Java后端系統因其豐富的技術棧和復雜的業務邏輯&#xff0c;常常面臨啟動延遲、性能瓶頸、異常錯誤等多種挑戰。從核心語言、Web框架到分布式微服務及緩…

數字人引領政務新風尚:智能設備助力政務服務

在信息技術飛速發展的今天&#xff0c;政府機構不斷探索提升服務效率和改善服務質量的新途徑。實時交互數字人在政務服務中的應用正成為一大亮點&#xff0c;通過將“數字公務員”植入各種橫屏智能設備中&#xff0c;為民眾辦理業務提供全程輔助。這種創新不僅優化了政務大廳的…