強化學習系列--dpo損失函數

DPO 概要

  1. DPO(Direct Preference Optimization,直接偏好優化)是由斯坦福大學等研究團隊于2023年提出的一種偏好優化算法,可用于LLM、VLM與MLLM的對齊訓練。

  2. 算法基于PPO的RLHF基礎上進行了大幅簡化。DPO算法跳過了訓練獎勵模型這一中間過程,直接(Direct)優化策略模型 ——這正是DPO命名中“D(Direct)”的含義所在。

主要流程

  1. 數據收集: 基于SFT訓練的模型作為推理模型,用戶輸入prompt,模型多次推理,找到好的答案和不好的答案。如果都是不好(rejected)的答案,則人工修改把不好的答案變為好的答案。

    標數據收集
  2. 主要包含兩個基礎模型,策略模型&參考模型(不需要Reward模型)。 在trl強化學習框架中,只需要傳入策略模型,參考模型會復制一份策略模型。

    1. 策略模型是DPO需要訓練的模型,后用在項目中的模型。策略模型的權重直接復制SFT階段微調模型的權重

    2. 參考模型是策略模型的幫襯,其權重參數凍結不變。主要兩個作用,其一協助其計算reward loss,其二計算kl正則項,防止其訓練偏移初始SFT模型太遠,由一個β參數控制。

  3. β參數控制含義

    1. 較大 beta(如 1.0):放大 reward 或 logp 的差異,使模型更“自信”地傾向于較優樣本,但容易過擬合或 reward 震蕩。

    2. 較小 beta(如 0.1):差異被壓縮,模型訓練更穩定,但收斂較慢、辨別力較弱。

    3. 極小 beta(趨近于 0):差異幾乎無效,模型無法區分好壞樣本,退化為隨機訓練

  4. ?整體流程如下:

  5. 具體流程

    DPO訓練流程細節

九個損失函數解析

"loss": 1.8678"rewards/chosen": 42.519317626953125"rewards/rejected": -33.865535736083984"rewards/accuracies": 0.865429699420929"rewards/margins": 76.38734436035156"logps/chosen": -948.4149780273438"logps/rejected": -1285.1175537109375"logits/chosen": 5.363300800323486"logits/rejected": 4.879658222198486
  1. logps/chosen和logps/rejected: logps 是模型生成 token 概率,在歸一化后(softmax)取 log 后的值(log prob)。

    #1 把 prompt 和 response 拼接起來作為輸入
    input = prompt + response
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch# 加載 tokenizer 和模型
    tokenizer = AutoTokenizer.from_pretrained("your-model-name")
    model = AutoModelForCausalLM.from_pretrained("your-model-name").cuda()# 設置 prompt 和 response
    prompt = "你今天心情怎么樣?"
    response = "我今天很開心,太陽出來了,我們一起去玩吧!"# 拼接輸入
    full_input = prompt + response
    encodings = tokenizer(full_input, return_tensors="pt").to("cuda")
    input_ids = encodings["input_ids"]# 找到 response 的起始位置
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
    response_start = prompt_ids.shape[-1]# 前向推理,獲取 logits
    with torch.no_grad():outputs = model(**encodings)logits = outputs.logits# 計算 log probabilities
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)# 獲取 response 部分 token 的 log probability
    response_token_ids = input_ids[:, response_start:]
    response_logits = log_probs[:, response_start - 1:-1, :]  # 對應 shift
    response_logp = torch.gather(response_logits, 2, response_token_ids.unsqueeze(-1)).squeeze(-1)# 平均 log probability(整個 response)
    logp_response = response_logp.mean()logps_chosen = compute_logp(prompt, chosen, actor_model)
    logps_rejected = compute_logp(prompt, rejected, actor_model)
    logps_ref_chosen = compute_logp(prompt, chosen, ref_model)
    logps_ref_rejected = compute_logp(prompt, rejected, ref_model)
  2. logits/chosen和logits/rejected: 模型輸出的raw score(未進行歸一化)求平均

    # 模型輸出:logits = [batch_size, seq_len, vocab_size]
    # 獲取 chosen 的最后一個 token 的 logit:
    logit_chosen = logits[:, -1, :]  # 通常是這個位置
    logits/chosen = logit_chosen.mean().item()
    # 拿出 chosen response 部分的 token 對應的 logit 向量
    logits_response = logits[:, prompt_len:, :]  # mask 掉 prompt 部分
    logits/chosen = logits_response.mean().item()
  3. reward 計算方法

    chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
    rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
    reward_accuracies = (chosen_rewards > rejected_rewards).float()
    metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
    metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
    metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
    metrics[f"{prefix}rewards/margins"] = (
    self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
  4. Loss 計算方法

    本次默認使用sigmoidlogratios = chosen_logps - rejected_logpsref_logratios = ref_chosen_logps - ref_rejected_logps                logratios = logratios.to(self.accelerator.device)ref_logratios = ref_logratios.to(self.accelerator.device)logits = logratios - ref_logratios losses = (-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)- F.logsigmoid(-self.beta * logits) * self.label_smoothing )
    其他計算方法如下(后續介紹):"hinge","ipo",
    "exo_pair","nca_pair","robust","bco_pair",
    "sppo_hard","aot","apo_down""aot_pair","apo_zero","discopop",
  5. 關系理解

    指標

    含義

    關系

    logits

    每個 token 的原始輸出分數(未歸一化)

    模型輸出的raw score(未進行歸一化)求平均

    logps

    所有 token 的 log 概率之和(對 logit softmax 后求 log,token-wise 累加)

    來自 logits → softmax → log(prob) → sum over tokens

    rewards

    在 logp-based reward 情況下,reward 就是 sum(logps)/len(tokens)

    eval_rewards/chosen == eval_logps/chosen/len(tokens)

  6. 主要關注指標

    指標名

    含義

    影響

    loss

    當前 batch 的 DPO/IPO 損失值

    反映訓練是否有效收斂,是否有發散/震蕩

    rewards/margins

    reward_chosen - reward_rejected 的平均值

    反映模型區分正負樣本的能力是否提升

    rewards/accuracies

    reward_chosen > reward_rejected 的比例

    反映偏好判斷正確率是否提高

    logs/chosen& logs/rejected

    每個 sample 的對數似然總和

    趨勢變化判斷 token-level 擬合趨勢

其他思考

1.? logps/chosen是負的合理嗎

logps(y_{chosen}|x})logps(y_{chosen}|x})?是模型對生成chosen回復時,每個token的概率取對數后加總, 由于每一個token的概率?,所以。p(yt,y<t)∈(0,1),所以logp(yt)<0。?所以累加一段文本后,整個logp通常是一個比較大的負值。

2. reward為負值

因為是?rchosen=logπθ(ychosen|x)?,如果沒有額外reward打分模型,則?r=sum(logps)/len(logps)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/87008.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/87008.shtml
英文地址,請注明出處:http://en.pswp.cn/web/87008.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

UniApp完全支持快應用QUICKAPP-以及如何采用 Uni 模式開發發行快應用優雅草卓伊凡

UniApp完全支持快應用QUICKAPP-以及如何采用 Uni 模式開發發行快應用優雅草卓伊凡 一、UniApp 對快應用的支持深度 UniApp 已完全支持快應用的開發和發布&#xff0c;具體包括&#xff1a; 兩種渲染模式&#xff1a; Webview 渲染&#xff08;快應用 Light 版&#xff09;&a…

js 允許生成特殊的變量名 基于字符集編碼混淆的 XSS 繞過漏洞 -- Google 2025 Lost In Transliteration

題目實現了一個字符轉換工具 在/file路由用戶可以通過 ct 參數自定義 Content-Type // 文件路由 - 提供靜態文件服務&#xff08;JS和CSS&#xff09;&#xff0c;支持內容類型驗證 app.MapGet("/file", (string filename "", string? ct null, string?…

【仿muduo庫實現并發服務器】LoopThreadPool模塊

仿muduo庫實現并發服務器 1.LoopThread模塊1.1成員變量1.2構造函數13線程入口函數1.4獲取eventloop對象GetLoop() 2.LoopThreadPool模塊2.1成員變量2.2構造函數2.3配置線程數量2.4按照配置數量創建線程2.5依次分配Eventloop對象 1.LoopThread模塊 這個模塊是為了將EventLoop與…

華為云Flexus+DeepSeek征文|基于Dify構建文本/圖像/視頻生成工作流

華為云FlexusDeepSeek征文&#xff5c;基于Dify構建文本/圖像/視頻生成工作流 一、構建文本/圖像/視頻生成工作流前言二、構建文本/圖像/視頻生成工作流環境2.1 基于FlexusX實例的Dify平臺2.2 基于MaaS的模型API商用服務 三、構建文本/圖像/視頻生成工作流實戰3.1 配置Dify環境…

相機-IMU聯合標定:IMU更新頻率

文章目錄 ??簡介?? IMU頻率參數錯誤設置的影響? 相機-IMU聯合標定失敗:Optimization failed!?? 確定IMU更新頻率直接通過 rostopic hz 檢查實際頻率檢查 IMU 驅動或數據手冊從 bag 文件統計頻率在這里插入圖片描述修改 `update_rate` 的注意事項**最終建議****常見問題…

動手實踐:如何提取Python代碼中的字符串變量的值

要提取Python代碼中所有變量類型為字符串的變量的值&#xff0c;但不執行代碼&#xff08;避免安全風險&#xff09;&#xff0c;可以通過靜態分析代碼的抽象語法樹&#xff08;AST&#xff09;來實現。以下是完整的解決方案&#xff1a; 本文由「大千AI助手」原創發布&#xf…

Python中字符串isalpha()函數詳解

在 Python 中&#xff0c;isalpha() 是字符串&#xff08;string&#xff09;類型的內置方法&#xff0c;用于檢查字符串中的所有字符是否都是字母字符&#xff08;alphabetic character&#xff09;。以下是詳細說明&#xff1a; 一、基本功能 返回值&#xff1a;布爾值&…

Gradio全解13——MCP詳解(4)——TypeScript包命令:npm與npx

Gradio全解13——MCP詳解&#xff08;4&#xff09;——TypeScript包命令&#xff1a;npm與npx 第13章 MCP詳解13.4 TypeScript包命令&#xff1a;npm與npx13.4.1 概念區分1. npm概念與運行邏輯2. npx概念及特點 13.4.2 操作示例1. 使用npm執行包2. 使用npx執行包3. 常用npm命令…

《推客小程序全鏈路開發指南:從架構設計到裂變運營》

在移動互聯網流量紅利逐漸消退的今天&#xff0c;如何低成本獲客成為企業營銷的核心痛點。推客小程序作為一種基于社交關系的裂變營銷工具&#xff0c;正成為企業突破增長瓶頸的利器。本文將為您全面解析推客小程序的開發定制全流程&#xff0c;幫助您打造專屬的社交裂變營銷平…

中鈞科技參加中亞數字經濟對話會,引領新疆企業數字化新征程!

6月27 日&#xff0c;烏魯木齊成為數字經濟領域的焦點&#xff0c;中國新疆 - 中亞國家數字經濟和數字貿易企業對話會在此盛大舉行。 來自中亞國家及新疆數字經濟領域的100 余位核心代表齊聚一堂&#xff0c;圍繞數字經濟時代的機遇、挑戰與策略展開深度探討。 本次對話會由新…

k8s一鍵部署tongweb企業版7049m6(by why+lqw)

聲明 1.此貼僅供參考&#xff0c;請根據自身需求在測試環境測試和修改。 安裝準備 1.獲取對應的安裝包和授權,并將授權和安裝包放在同一個目錄下 2.docekr已配置遠程倉庫 3.提前拉取jdk的鏡像&#xff08;這里配置了使用openjdk:8&#xff09; 安裝 將以下內容復制到k8s_…

Qt 與 Halcon 聯合開發六:基于海康SDK設計完整的相機類【附源碼】

在現代工業自動化、機器人視覺、等領域&#xff0c;相機模塊的作用至關重要。通過相機模塊采集到的圖像數據&#xff0c;我們能夠進行一系列的圖像處理和分析。為了高效地控制相機和處理圖像&#xff0c;本篇文章將介紹如何使用Qt和Halcon聯合開發一個相機模塊&#xff0c;幫助…

第7篇:Gin模板引擎——服務端頁面渲染

作者:GO兔 博客:https://luckxgo.cn 分享大家都看得懂的博客 引言 在Web開發中&#xff0c;服務端頁面渲染(SSR)依然是構建動態網頁的重要方式。Gin框架雖然以API開發見長&#xff0c;但也內置了強大的模板引擎支持&#xff0c;基于Go標準庫的html/template包實現。本文將深入…

RagFlow 源碼部署啟動指南

一、環境準備 1. 安裝 uv 和 pre-commit 如果已安裝&#xff0c;可跳過。推薦使用官方方式安裝&#xff0c;避免報錯&#xff1a; pipx install uv pre-commit export UV_INDEXhttps://mirrors.aliyun.com/pypi/simple安裝報錯 使用清華源安裝&#xff1a; pipx install uv…

【Python基礎】12 閑談分享:Python用于無人駕駛的未來

引言&#xff1a;一個程序員的自動駕駛夢想 還記得2016年的那個秋天&#xff0c;我第一次坐進特斯拉Model S的駕駛座&#xff0c;體驗Autopilot功能。當方向盤開始自己轉動&#xff0c;車輛在高速公路上自動跟隨前車時&#xff0c;我的內心涌起了一種奇妙的感覺——這不就是我…

為什么js是單線程?

js單線程&#xff0c;同一時間只能做一件事 。js的單線程 主要與它的用途有關。作為瀏覽器腳本語言&#xff0c;js的主要用途是與用戶互動&#xff0c;以及操作DOM。這決定了它只能是單線程&#xff0c;否則會帶來很復雜的同步問題。如果js同時有兩個線程&#xff0c;一個線程在…

DVWA靶場通關筆記-文件包含(Medium級別 9種滲透方法)

目錄 一、文件包含 1、原因 2、危害 3、防范措施 二、代碼審計&#xff08;Medium級別&#xff09; 1、滲透準備 &#xff08;1&#xff09;配置php.ini &#xff08;2&#xff09;file1.php &#xff08;3&#xff09;file2.php &#xff08;4&#xff09;file3.php…

飛云翻倍布林(翻倍密碼系統四線布林版)雙安全系統+均價趨勢指標+日線周線MACD,組合操盤技術圖文分享

如上圖組合操盤套裝指標&#xff0c;主圖指標-翻倍密碼系統四線布林版-飛云翻倍布林。副圖指標1-均價趨勢指標&#xff0c;跟蹤市場均價走勢和趨勢&#xff1b;副圖指標2-日線周線MACD指標&#xff0c;跟蹤日線和周線兩個級別的MACD多空走勢以及共振與否。 主圖指標-飛云翻倍布…

《匯編語言:基于X86處理器》第6章 條件處理(1)

本章向程序員的匯編語言工具箱中引入一個重要的內容&#xff0c;使得編寫出來的程序具備作決策的功能。幾乎所有的程序都需要這種能力。首先&#xff0c;介紹布爾操作&#xff0c;由于能影響CPU狀態標志&#xff0c;它們是所有條件指令的核心。然后&#xff0c;說明怎樣使用演繹…

【分治思想】歸并排序 與 逆序對

歸并排序 歸并排序是一種分治算法&#xff0c;怎么分&#xff0c;怎么治&#xff1f; 分&#xff1a;通過遞歸不斷把數組分成兩半&#xff0c;直到每個子數組只剩 1 個元素&#xff08;天然有序&#xff09;治&#xff1a;把兩個已經排好序的子數組合并成一個有序數組。 把問…