【多模態】DPO學習筆記

DPO學習筆記

  • 1 原理
    • 1.0 名詞
    • 1.1 preference model
    • 1.2 RLHF
    • 1.3 從RLHF到DPO
      • A.解的最優形式
      • B. DPO下參數估計
      • C. DPO下梯度更新
      • D. DPO訓練的穩定性
  • 2 源代碼
    • 2.1 數據集構成
    • 2.2 計算log prob
    • 2.3 DPO loss

1 原理

1.0 名詞

  • preference model:對人類偏好進行建模,這個"model"不是DL model
  • policy model:最終要訓練得到的LLM πθ\pi_\thetaπθ?
  • reward model:用來評價LLM生成的結果有多符合人類偏好

1.1 preference model

  • 是一種者范式、定義,是用來預測人類對不同輸出項之間相對偏好概率的模型,例如,在比較兩個響應時,偏好模型可以估計出“響應A比響應B更受歡迎”的概率
  • DPO中使用的是Bradley–Terry 模型來定義偏好的概率形式,給定2個選項ywy_wyw?yly_lyl?,Bradley–Terry 定義的的ywy_wyw?yly_lyl?好的概率為
    p(yw≥yl)=exp(θw)exp(θw)+exp(θl)p(y_w \ge y_l)=\frac{exp(\theta_w)}{exp(\theta_w)+exp(\theta_l)} p(yw?yl?)=exp(θw?)+exp(θl?)exp(θw?)?

1.2 RLHF

在這里插入圖片描述
RLHF需要使用人標注的偏好數據對,先訓練一個reward model,然后再讓reward model和LLM做強化學習
【1】SFT訓練LLM: 使用目標任務的訓練數據訓練得到的模型記為πSFT\pi^{SFT}πSFT
【2】訓練reward model: 使用目標任務的另一份數據xxx輸入πSFT\pi^{SFT}πSFT,每份數據得到2個輸出,記為(y1,y2)~πSFT(y∣x)(y_1,y_2) \sim \pi^{SFT}(y \mid x)(y1?,y2?)πSFT(yx)。這些成對的數據給到人工標注者,進行偏好標注,(y1,y2)(y_1,y_2)(y1?,y2?)里面人工覺得回答的好的數據為ywy_wyw?,覺得回答的不好的數據為yly_lyl?,得到的數據集為D={xi,ywi,yli}i=1N\mathcal{D}=\{x^{i},y^i_w,y^i_l\}^N_{i=1}D={xi,ywi?,yli?}i=1N?。假設這種偏好產生自一個隱藏的獎勵模型r?(y,x)r^*(y,x)r?(y,x),當使用Bradley-Terry模型來建模,人類偏好p?p^*p?的分布可以表示為
p?(yw?yl∣x)=exp(r?(x.y1))exp(r?(x.y1))+exp(r?(x.y2))p^*(y_w \succ y_l \mid x)=\frac{exp(r^*(x.y_1))}{exp(r^*(x.y_1))+exp(r^*(x.y_2))} p?(yw??yl?x)=exp(r?(x.y1?))+exp(r?(x.y2?))exp(r?(x.y1?))?
??可以形式化獎勵模型參數為r?(x,y)r_\phi(x,y)r??(x,y)并且使用極大似然估計在數據集D\mathcal{D}D上估計參數,建模為二分類問題,損失函數可以為(也可以是其他形式,相減比較符合認知):
LR(r?,D)=?E(x,yw,yl)~D[logσ(r?(x,yw)?r?(x,yl))]\mathcal{L}_R(r_\phi,\mathcal{D})=-\mathbb{E}_{(x,y_w,y_l)\sim\mathcal{D}}[log \sigma(r_\phi(x,y_w)-r_\phi(x,y_l))]LR?(r??,D)=?E(x,yw?,yl?)D?[logσ(r??(x,yw?)?r??(x,yl?))]

【3】RL微調: 在RL階段,優化目標帶有KL約束
max?πθEx~D,y~πθ(y∣x)[r?(x,y)?βDKL[πθ(y∣x)∥πref(y∣x)]]\max_{\pi_{\theta}}\mathbb{E}_{x \sim \mathcal{D},y \sim \pi_{\theta}(y \mid x)}[r_\phi(x,y)-\beta\mathbb{D}_{KL}[\pi_{\theta}(y \mid x)\parallel \pi_{ref}(y \mid x)]] πθ?max?ExD,yπθ?(yx)?[r??(x,y)?βDKL?[πθ?(yx)πref?(yx)]]

1.3 從RLHF到DPO

A.解的最優形式

??首先,根據RL優化目標的形式,獎勵函數為rrr,最優的策略π\piπ的形式為
πr(y∣x))=1Z(x)πref(y∣x)exp(1βr(x,y))\pi_r(y \mid x))=\frac{1}{Z(x)}\pi_{ref}(y \mid x) exp(\frac{1}{\beta}r(x,y)) πr?(yx))=Z(x)1?πref?(yx)exp(β1?r(x,y))
其中Z(x)=∑yπref(y∣x)exp(1βr(x,y))Z(x)=\sum_{y}\pi_{ref}(y \mid x) exp(\frac{1}{\beta}r(x,y))Z(x)=y?πref?(yx)exp(β1?r(x,y))。之所以能得到這個形式在原論文的附錄中有推導
在這里插入圖片描述
??里面的第3步到第4步是因為可以引入Z(x)Z(x)Z(x)構造一個新的概率分布,Z(x)Z(x)Z(x)是歸一化因子,保證π~(y∣x)\tilde{\pi} (y \mid x)π~(yx)是有效的概率分布:
π~(y∣x)=1Z(x)πrefexp(1βr(x,y))\tilde{\pi} (y \mid x)=\frac{1}{Z(x)}\pi_{ref}exp(\frac{1}{\beta}r(x,y))π~(yx)=Z(x)1?πref?exp(β1?r(x,y))

??這樣,原來的式子
logπ(y∣x)πref(y∣x)=logπ(y∣x)?πref(y∣x)?log[exp(1βr(x,y))]=logπ(y∣x)π~(y∣x)?logZ(x)log \frac{\pi(y \mid x)}{\pi_{ref}(y \mid x)} =log\pi(y \mid x)-\pi_{ref}(y \mid x) - log[exp(\frac{1}{\beta}r(x,y))] \\ =log \frac{\pi(y \mid x)}{\tilde{\pi}_(y \mid x)} - log Z(x) logπref?(yx)π(yx)?=logπ(yx)?πref?(yx)?log[exp(β1?r(x,y))]=logπ~(?yx)π(yx)??logZ(x)

??又因π\piπ的形式只需要滿足是合法的概率分布就可以,因此形式上可以替換,以及Z(x)Z(x)Z(x)不是yyy的函數,所以期望寫進去不會對logZ(x)log Z(x)logZ(x)有影響,得到了最優策略下,策略函數的形式(給定xxx的情況下輸出yyy的概率 / 在給定狀態SSS的情況下,下一個時間的進入狀態S′S'S的概率)
π?(y∣x)=1Z(x)πref(y∣x)exp(1βr(x,y))\pi^*(y \mid x)= \frac{1}{Z(x)}\pi_{ref}(y \mid x) exp(\frac{1}{\beta} r(x,y)) π?(yx)=Z(x)1?πref?(yx)exp(β1?r(x,y))
在這里插入圖片描述

B. DPO下參數估計

  • 即使得到了最優策略πr\pi_rπr?的形式,并且即使把里面的r(x,y)r(x,y)r(x,y)用MLE估計的rrr來替換,里面也有一個Z(x)Z(x)Z(x)需要估計,Z(x)Z(x)Z(x)的計算是很復雜的,里面的"狀態"或者說詞表yyy很大的情況下開銷大
  • 但是可以進一步把式子整理一下,重新表示一下reward函數
    r(x,y)=βlogπr(y∣x)πref(y∣x)+βlogZ(x)r(x,y)=\beta log \frac{\pi_r(y \mid x)}{\pi_{ref}(y \mid x)}+ \beta log Z(x)r(x,y)=βlogπref?(yx)πr?(yx)?+βlogZ(x)
  • 帶入原始的Bradley-Terry的式子,會發現,最后衡量偏好的函數里面,沒有reward function Z(x)Z(x)Z(x)這一項需要計算了抵消掉了

在這里插入圖片描述

  • 所以DPO的目標是提升yw?yly_w \succ y_lyw??yl?的概率,損失函數的形式為
    LDPO(πθ;πref)=?E(x,yw,wl)~D[logσ(βlogπθ(yw∣x)πref(yw∣x)?βlogπθ(yl∣x)πref(yl∣x))]\mathcal{L}_{DPO}(\pi_\theta;\pi_{ref}) = -\mathbb{E}_{(x,y_w,w_l)\sim \mathcal{D}}[log \sigma(\beta log \frac{\pi_\theta(y_w \mid x)}{\pi_{ref}(y_w \mid x)} - \beta log \frac{\pi_\theta(y_l \mid x)}{\pi_{ref}(y_l \mid x)}) ] LDPO?(πθ?;πref?)=?E(x,yw?,wl?)D?[logσ(βlogπref?(yw?x)πθ?(yw?x)??βlogπref?(yl?x)πθ?(yl?x)?)]

C. DPO下梯度更新

在這里插入圖片描述

  • 和人類偏好差異越大的,前面的系數越大

D. DPO訓練的穩定性

在這里插入圖片描述

  • 第二項為歸一化項是常數是因為對當前xxx,遍歷了所有的yyy
  • 減少極端值的影響:通過指數加權平均,極端值的影響會被削弱,從而使得獎勵函數更加平滑
  • 穩定梯度估計:由于獎勵函數變得更加平滑,策略梯度的估計也會更加穩定,方差會顯著減小

2 源代碼

RLAIF-V:https://github.com/RLHF-V/RLAIF-V/tree/main

2.1 數據集構成

  • chose——人類偏好的回答
  • rejected——SFT階段的模型回答
  • ref_win_logp——人類偏好回答的所有token的log_probability之和
  • ref_rej_logp——模型回答的的所有token的log_probability之和
  • ref_win_avg_logp——人類偏好回答的所有token的log_probability之和 / 回答長度的token數
data_dict = {'image': image,"question": question,"chosen": chosen,"rejected": rejected,"idx": sample['idx'],"metainfo": metainfo
}
logps=json.loads(sample['logps']) # 調用/muffin下面的./eval/muffin_inference_logp.pyif type(logps) == type([]):(data_dict['ref_win_logp'], data_dict['ref_win_avg_logp'], data_dict['ref_win_per_token_logp'],data_dict['ref_rej_logp'], data_dict['ref_rej_avg_logp'], data_dict['ref_rej_per_token_logp']) = logps
else:(data_dict['ref_win_logp'], data_dict['ref_win_avg_logp'], data_dict['ref_win_per_token_logp'],data_dict['ref_rej_logp'], data_dict['ref_rej_avg_logp'], data_dict['ref_rej_per_token_logp']) = logps['logps']return data_dict

2.2 計算log prob

def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, return_per_token_logp=False, return_all=False, tokenizer=None) -> torch.FloatTensor:"""Compute the log probabilities of the given labels under the given logits.Args:logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length)Returns:A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits."""assert logits.shape[:-1] == labels.shape, f'logits.shape[:-1]={logits.shape[:-1]}, labels.shape={labels.shape}'labels = labels[:, 1:].clone()logits = logits[:, :-1, :]loss_mask = (labels != -100)# dummy token; we'll ignore the losses on these tokens laterlabels[labels == -100] = 0per_token_logps = torch.gather(logits.log_softmax(-1), dim=2,index=labels.unsqueeze(2)).squeeze(2) # get log probabilities for each token in labelslog_prob = (per_token_logps * loss_mask).sum(-1)average_log_prob = log_prob / loss_mask.sum(-1)

2.3 DPO loss

  • policy model指的是正在訓練的模型,ref model是之前SFT階段的模型
  • 注意policy_chosen_logps這些是log 的probability,所以和原始的DPO的loss公式是完全等價的
def get_beta_and_logps(data_dict, model, args, is_minicpm=False, is_llava15=False):win_input_ids = data_dict.pop('win_input_ids')rej_input_ids = data_dict.pop('rej_input_ids')ref_win_logp = data_dict.pop('ref_win_logp')ref_rej_logp = data_dict.pop('ref_rej_logp')log_prob, average_log_prob = get_batch_logps(output.logits, concatenated_labels, return_per_token_logp=False)if args.dpo_use_average:concatenated_logp = average_log_probwin_size = win_input_ids.shape[0]rej_size = rej_input_ids.shape[0]policy_win_logp, policy_rej_logp = concatenated_logp.split([win_size, rej_size])  # 默認的是average的log_logits,值越大越置信return policy_win_logp, policy_rej_logp, ref_win_logp, ref_rej_logp, betadef dpo_loss(policy_chosen_logps: torch.FloatTensor,policy_rejected_logps: torch.FloatTensor,reference_chosen_logps: torch.FloatTensor,reference_rejected_logps: torch.FloatTensor,beta: float,reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:"""Compute the DPO loss for a batch of policy and reference model log probabilities.Args:policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.Returns:A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).The losses tensor contains the DPO loss for each example in the batch.The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively."""pi_logratios = policy_chosen_logps - policy_rejected_logps  # log(\pi(a_i | x)) - log(\pi(b_i | x)) = log(\pi(a_i | x) / \pi(b_i | x))ref_logratios = reference_chosen_logps - reference_rejected_logps  # 完全等價的if reference_free:ref_logratios = 0logits = pi_logratios - ref_logratioslosses = -F.logsigmoid(beta * logits)chosen_rewards = beta * (policy_chosen_logps -reference_chosen_logps).detach()rejected_rewards = beta * \(policy_rejected_logps - reference_rejected_logps).detach()return losses, chosen_rewards, rejected_rewards############# 調用為policy_win_logp, policy_rej_logp, ref_win_logp, ref_rej_logp, beta = get_beta_and_logps(data_dict, model, self.args, is_llava15=True) # 這些都是averaged的token的log_logitslosses, chosen_rewards, rejected_rewards = dpo_loss(policy_win_logp,policy_rej_logp,ref_win_logp,ref_rej_logp,beta=beta)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/92129.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/92129.shtml
英文地址,請注明出處:http://en.pswp.cn/web/92129.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

2025最新、UI媲美豆包、DeepSeek等AI大廠的AIGC系統 - IMYAI源碼部署教程

IMYAI 系統部署與使用手冊 一、系統演示 🔹 快速體驗 前端演示地址:https://super.imyaigc.com后臺演示地址:https://super.imyaigc.com/settings 🔹 技術架構 前端:Vite Vue3 NaiveUI TailwindCSS Plyr后端&…

【關于Java的反射】

在 Java 編程中,反射(Reflection) 是一個非常強大的工具,它允許你在運行時動態地獲取類的信息、創建對象、調用方法和訪問字段。雖然反射功能強大,但它也有一些局限性和性能開銷,因此需要謹慎使用。一、什么…

Gitee推出“移動軟件工廠“解決方案 解決嵌入式與涉密場景研發困局

Gitee推出"移動軟件工廠"解決方案 破解嵌入式與涉密場景研發困局 隨著數字化轉型浪潮的推進,軟件開發正面臨著前所未有的復雜環境挑戰。特別是在嵌入式系統、FPGA開發以及涉密信息系統等特殊場景下,研發團隊往往需要在高安全要求與有限網絡環境…

低功耗16*8位四線串行8*4按鍵陣矩LED驅動專用電路

概述:PC0340是占空比可調的LED顯示控制驅動電路。由16根段輸出、8根位輸出、數字接口、數據鎖存器、顯示存儲器、鍵掃描電路及相關控制電路組成了一個高可靠性的單片機外圍LED驅動電路。串行數據通過4線串行接口輸入到PC0340,采用LQFP44L的封裝形式。本產…

通過自定義注解加aop切面實現權限控制

前言:自定義注解,通過aop切面前置通知,對請求接口進行權限控制1,創建枚舉類package org.springblade.sample.annotationCommon;import lombok.AllArgsConstructor; import lombok.Getter;import java.util.Arrays; import java.ut…

IDS知識點

在網絡安全工程師、系統運維工程師等崗位的面試中,??IDS(Intrusion Detection System,入侵檢測系統)?? 是高頻考點,尤其是對網絡安全防護、安全監控類崗位。以下是IDS的核心考點和必須掌握的知識點,按優…

Adobe Analytics 數據分析平臺|全渠道客戶行為分析與體驗優化

Adobe Analytics 是業界領先的數據分析平臺,幫助企業實時追蹤客戶行為,整合多渠道數據,通過強大的分析與可視化工具深入分析客戶旅程,優化數字體驗。結合 Adobe Experience Cloud,Adobe Analytics 成為推動數字化增長和…

【輪播圖】H5端輪播圖、橫向滑動、劃屏效果實現方案——Vue3+CSS position/CSS scroller

文章目錄定位實現滑屏效果前置知識CSS: touch-action屬性CSS: transform屬性觸摸事件forEach回調占位符準備階段實現移動效果實現跟手效果觸摸結束優化完整代碼滾動實現滑屏效果前置知識CSS: scroll-snap-type屬性準備階段實現滑動效果實現吸附效果滾動條隱藏存在問題完整代碼s…

忘記了WordPress管理員密碼的找回方法

WordPress管理員密碼找回方法 如果您忘記了WordPress管理員密碼,可以通過以下幾種方法找回或重置: 方法1:通過電子郵件重置(最簡單) 訪問您的WordPress登錄頁面(通常是wodepress.com/wp-admin或wodepress.com/wp-login.php) 點擊”忘記密…

RAFT:讓語言模型更聰明地用文檔答題

RAFT:讓語言模型更聰明地用文檔答題 作者注: 本文旨在面向零基礎讀者介紹 UC Berkeley 提出的 RAFT(Retrieval-Augmented Fine-Tuning)方法。它是一種訓練語言模型的新方式,讓模型更好地利用“外部知識”——比如文檔、…

【緊急預警】NVIDIA Triton推理服務器漏洞鏈可導致RCE!

2025 年 8 月 4 日消息,NVIDIA 旗下的 Triton 推理服務器(一款支持 Windows 和 Linux 系統、用于大規模運行 AI 模型的開源平臺)被曝出一系列安全漏洞。這些漏洞一旦被利用,攻擊者有可能完全接管存在漏洞的服務器。 Wiz 安全公司…

基于深度學習的醫學圖像分析:使用PixelCNN實現醫學圖像生成

前言 醫學圖像分析是計算機視覺領域中的一個重要應用,特別是在醫學圖像生成任務中,深度學習技術已經取得了顯著的進展。醫學圖像生成是指通過深度學習模型生成醫學圖像,這對于醫學研究、疾病模擬和圖像增強等任務具有重要意義。近年來&#x…

React ahooks——副作用類hooks之useDebounceFn

useDebounceFn 是 ahooks 提供的用于函數防抖的 Hook,它可以確保一個函數在連續觸發時只執行最后一次。一、基本用法import { useDebounceFn } from ahooks; import { Button } from antd;const Demo () > {const { run } useDebounceFn(() > {console.log(…

【機器學習深度學習】 知識蒸餾

目錄 前言 一、什么是知識蒸餾? 二、知識蒸餾的核心意義 2.1 降低算力與成本 2.2 加速推理與邊緣部署 2.3 推動行業應用落地 2.4 技術自主可控 三、知識蒸餾的本質:大模型的知識傳承 四、知識蒸餾的“四重紅利” 五、DeepSeek的知識蒸餾實踐 …

Python高級編程與實踐:Python高級數據結構與編程技巧

高級數據結構:掌握Python中的高效編程技巧 學習目標 通過本課程,學員將深入了解Python中的高級數據結構,包括列表推導式、字典推導式、集合推導式和生成器表達式。學員將學習如何利用這些結構來編寫更簡潔、更高效的代碼,并了解它…

【C++】Stack and Queue and Functor

本文是小編鞏固自身而作,如有錯誤,歡迎指出!本次我們介紹STL中的stack和queue和其相關的一些容器和仿函數一.stack and queue1.適配器stack和queue其實不是真正意義上的容器,而是容器適配器,而容器適配器又是什么呢&am…

Python爬蟲實戰:研究OpenCV技術構建圖像數據處理系統

1. 引言 1.1 研究背景 在當今數字化時代,圖像作為一種重要的信息載體,廣泛存在于各類網站、社交媒體和在線平臺中。這些圖像數據涵蓋了從自然風光、人物肖像到商品展示、新聞事件等豐富內容,為數據分析和模式識別提供了寶貴的資源。隨著計算機視覺技術的快速發展,對大規模…

電感矩陣-信號完整性分析

電感矩陣:正如電容矩陣用于存儲許多信號路徑和返回路徑的所有電容量,我們也需要一個矩陣存儲許多導線的回路自感和回路互感值。需要牢記的是,這里的電感元件是回路電感。當信號沿傳輸線傳播時,電流回路沿信號路徑傳輸,然后立即從返…

JUC相關知識點總結

Java JUC(java.util.concurrent)是Java并發編程的核心工具包,提供了豐富的并發工具類和框架。以下是JUC的主要知識點,按難易程度分類,供你參考: 1. 基礎概念與工具類 1.1 并發與并行(易&#x…

激光頻率梳 3D 測量方案革新:攻克光學掃描遮擋,130mm 深孔測量精度達 2um

一、深孔測量的光學遮擋難題在精密制造領域,130mm 級深孔(如航空發動機燃油孔、模具冷卻孔)的 3D 測量長期受困于光學遮擋。傳統激光掃描技術依賴直射光束,當深徑比超過 10:1 時,孔壁中下部形成大量掃描盲區&#xff0…