【復現DeepSeek-R1之Open R1實戰】系列6:GRPO源碼逐行深度解析(上)

目錄

  • 4 GRPO源碼分析
    • 4.1 數據類 `GRPOScriptArguments`
    • 4.2 系統提示字符串 `SYSTEM_PROMPT`
    • 4.3 獎勵函數
      • 4.3.1 accuracy_reward函數
      • 4.3.2 verify函數
      • 4.3.3 format_reward函數
    • 4.4 將數據集格式化為對話形式
    • 4.5 初始化GRPO Trainer


【復現DeepSeek-R1之Open R1實戰】系列3:SFT和GRPO源碼逐行深度解析(上)
【復現DeepSeek-R1之Open R1實戰】系列5:SFT和GRPO源碼逐行深度解析(中)

4 GRPO源碼分析

前面兩篇博文已經詳細介紹了一些基礎知識和SFT源碼,本文繼續解讀GRPO源碼。與SFT源碼差不多的部分,我們就不展開細說了,這里只解析GRPO獨特的部分。

4.1 數據類 GRPOScriptArguments

該類使用了 Python 的 dataclass 裝飾器,這是一種簡化類定義的方式,特別是對于那些主要用來存儲數據的類。它繼承自 ScriptArguments 類。

  • reward_funcs: 這是一個列表,包含了一系列可能的獎勵函數名稱,默認值為 ["accuracy", "format"]。這些獎勵函數可能是用于評估模型性能的不同標準。

    reward_funcs: list[str] = field(default_factory=lambda: ["accuracy", "format"],metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"},
    )
    
  • cosine_min_value_wrongcosine_max_value_wrong: 分別表示錯誤答案在余弦相似度尺度上的最小和最大獎勵值,默認分別為 0.0-0.5

  • cosine_min_value_correctcosine_max_value_correct: 分別表示正確答案在余弦相似度尺度上的最小和最大獎勵值,默認分別為 0.51.0

  • cosine_max_len: 表示余弦相似度尺度的最大長度,默認值為 1000

  • repetition_n_grams: 表示用于重復懲罰獎勵的n-gram數量,默認值為 3

  • repetition_max_penalty: 表示重復懲罰獎勵的最大負值,默認值為 -1.0

每個字段都使用了 field() 函數來定義其默認值和元數據(如幫助信息)。這有助于工具和庫更好地理解和處理這些字段,例如生成命令行解析器時。

4.2 系統提示字符串 SYSTEM_PROMPT

SYSTEM_PROMPT = ("A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant ""first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning ""process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., ""<think> reasoning process here </think><answer> answer here </answer>"
)

字符串描述了一個對話場景,用戶先提問,助手首先思考推理過程,然后提供答案。推理過程和答案分別用 <think><answer> 標簽包裹,這種格式化有助于區分和識別不同的部分,和DeepSeek-R1的思考過程格式一致。

4.3 獎勵函數

獎勵函數的定義如下,GRPO默認用到了accuracy_reward和format_reward這兩個函數。

# Get reward functionsREWARD_FUNCS_REGISTRY = {"accuracy": accuracy_reward,"format": format_reward,"reasoning_steps": reasoning_steps_reward,"cosine": get_cosine_scaled_reward(min_value_wrong=script_args.cosine_min_value_wrong,max_value_wrong=script_args.cosine_max_value_wrong,min_value_correct=script_args.cosine_min_value_correct,max_value_correct=script_args.cosine_max_value_correct,max_len=script_args.cosine_max_len,),"repetition_penalty": get_repetition_penalty_reward(ngram_size=script_args.repetition_n_grams,max_penalty=script_args.repetition_max_penalty,),"length": len_reward,}reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

這段代碼定義了一個獎勵函數注冊表 REWARD_FUNCS_REGISTRY,并根據用戶提供的配置動態生成一個獎勵函數列表 reward_funcs。每個獎勵函數用于評估模型輸出的不同方面,如準確性、格式、推理步驟等。

  1. 注冊表定義
  • accuracy: 使用 accuracy_reward 函數評估模型輸出的準確性。
  • format: 使用 format_reward 函數評估模型輸出的格式。
  • reasoning_steps: 使用 reasoning_steps_reward 函數評估模型輸出的推理步驟。
  • cosine: 使用 get_cosine_scaled_reward 函數計算余弦相似度獎勵,參數包括:
    • min_value_wrong: 錯誤情況下的最小值。
    • max_value_wrong: 錯誤情況下的最大值。
    • min_value_correct: 正確情況下的最小值。
    • max_value_correct: 正確情況下的最大值。
    • max_len: 最大長度。
  • repetition_penalty: 使用 get_repetition_penalty_reward 函數計算重復懲罰獎勵,參數包括:
    • ngram_size: n-gram 的大小。
    • max_penalty: 最大懲罰值。
  • length: 使用 len_reward 函數評估模型輸出的長度。
  1. 動態生成獎勵函數列表
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
  • 根據 script_args.reward_funcs 中指定的獎勵函數名稱,從 REWARD_FUNCS_REGISTRY 中獲取相應的獎勵函數,并生成一個列表 reward_funcs

4.3.1 accuracy_reward函數

該函數用于計算模型生成的補全與真實答案之間的準確性獎勵。它通過解析和驗證生成的內容與真實答案來確定獎勵值。

def accuracy_reward(completions, solution, **kwargs):"""Reward function that checks if the completion is the same as the ground truth."""contents = [completion[0]["content"] for completion in completions]rewards = []for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)if len(gold_parsed) != 0:# We require the answer to be provided in correct latex (no malformed operators)answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)# Reward 1 if the content is the same as the ground truth, 0 otherwisereward = float(verify(answer_parsed, gold_parsed))else:# If the gold solution is not parseable, we reward 1 to skip this examplereward = 1.0print("Failed to parse gold solution: ", sol)rewards.append(reward)return rewards
  • completions (list): 包含多個補全結果的列表,每個補全結果是一個包含內容的字典列表。
  • solution (list): 真實答案的列表。
  • kwargs: 其他可選參數(在本函數中未使用)。
  1. 提取補全內容

    contents = [completion[0]["content"] for completion in completions]
    
    • completions 列表中提取每個補全的第一個內容(假設每個補全是單個元素的列表),形成一個新的 contents 列表。
  2. 初始化獎勵列表

    rewards = []
    
  3. 遍歷每個補全和對應的真實答案

    for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)
    
    • 使用 zip 函數將 contentssolution 配對。
    • 對于每一對補全內容和真實答案,首先解析真實答案 sol,使用 parse 函數提取其中的信息。
  4. 處理解析結果

    if len(gold_parsed) != 0:answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)
    
    • 如果解析得到的真實答案 gold_parsed 非空,則繼續解析生成的補全內容 content
    • 使用 LatexExtractionConfigNormalizationConfig 進行詳細配置,確保解析過程中考慮了各種格式要求(如方程、單位等)。
  5. 計算獎勵

    reward = float(verify(answer_parsed, gold_parsed))
    
    • 使用 verify 函數比較生成的補全解析結果和真實答案的解析結果。
    • 如果兩者匹配,則返回 1.0,否則返回 0.0
  6. 處理無法解析的情況

    else:reward = 1.0print("Failed to parse gold solution: ", sol)
    
    • 如果真實答案無法解析,則默認給予獎勵 1.0 并打印警告信息。
  7. 添加獎勵到列表

    rewards.append(reward)
    
  8. 返回所有獎勵

    return rewards
    

4.3.2 verify函數

該函數用于驗證目標表達式是否與參考表達式匹配,它通過多種比較策略來處理不同的數學對象(如數字、表達式、集合、矩陣等),并提供靈活的配置選項以適應不同的需求。

def verify(gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, float_rounding: int=6,numeric_precision: int=15,strict: bool=True,timeout_seconds: int=3
) -> bool:
  • gold: 參考或正確的表達式,可以是單個 SymPy 表達式(BasicMatrixBase)、字符串或這些類型的列表。
  • target: 需要驗證的表達式,類型同 gold
  • float_rounding: 浮點數舍入的小數位數,默認為 6。
  • numeric_precision: 數值比較時考慮的小數位數,默認為 15。
  • strict: 是否啟用嚴格比較模式,默認為 True
    • 在嚴格模式下:變量很重要,集合不可與元組比較。
    • 在非嚴格模式下:變量按位置匹配,集合可與元組比較。
  • timeout_seconds: 單次比較操作的最大超時時間(秒),默認為 3 秒。
  1. 定義內部比較函數 compare_single_extraction

    @timeout(timeout_seconds=timeout_seconds)
    def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool:...
    
    • 使用裝飾器 @timeout 設置超時保護,默認超時時間為 timeout_seconds
    • 比較兩個表達式:
      • 如果兩者都是 SymPy 表達式(BasicMatrixBase),則調用 sympy_expr_eq 進行比較。
      • 如果兩者都是字符串,則進行簡單的字符串比較。
  2. 定義包裝函數 compare_single_extraction_wrapper

    def compare_single_extraction_wrapper(g, t):try:return compare_single_extraction(g, t)except Exception as e:logger.exception(f"Error comparing {g} and {t}")return False
    
    • 包裝 compare_single_extraction,捕獲并記錄任何異常,返回 False 以避免程序中斷。
  3. 處理輸入列表

    if not isinstance(gold, list):gold = [gold]
    if not isinstance(target, list):target = [target]
    
    • 如果 goldtarget 不是列表,則將其轉換為單元素列表,以便統一處理。
  4. 組合所有可能的比較

    return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target))
    
    • 使用 itertools.product 生成所有可能的 goldtarget 組合。
    • 對每個組合調用 compare_single_extraction_wrapper,如果任意一對匹配成功,則返回 True

4.3.3 format_reward函數

函數用于檢查給定的完成文本是否符合特定的格式,它驗證完成文本是否包含 <think><answer> 標簽,并且這兩個標簽的內容是非空的。

def format_reward(completions, **kwargs):"""Reward function that checks if the completion has a specific format."""pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"completion_contents = [completion[0]["content"] for completion in completions]matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]return [1.0 if match else 0.0 for match in matches]
  • completions: 這是一個列表,其中每個元素都是一個包含完成內容的對象(通常是字典)。假設每個完成對象的第一個元素包含一個鍵 "content",其值是需要檢查的文本。
  • kwargs: 其他關鍵字參數,這里沒有使用,但可以為未來的擴展提供靈活性。
  1. 正則表達式模式定義

    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    
    • 這個正則表達式用于匹配字符串是否以 <think> 開始,緊接著是任意字符(非貪婪匹配),然后是 </think>,接著可能有任意數量的空白字符(包括換行符),最后是以 <answer> 開始并以 </answer> 結束。
    • .*? 是非貪婪匹配,確保盡可能少地匹配字符。
    • \s* 匹配零個或多個空白字符(包括換行符)。
    • re.DOTALL | re.MULTILINE 標志允許點號 . 匹配所有字符(包括換行符),并且使多行文本中的每一行都可以獨立匹配。
  2. 提取完成內容

    completion_contents = [completion[0]["content"] for completion in completions]
    
    • 這里通過列表推導式從 completions 列表中提取每個完成對象的第一個元素的 "content" 字段,形成一個新的列表 completion_contents
  3. 匹配正則表達式

    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    
    • 使用 re.match 函數對 completion_contents 中的每個內容應用正則表達式模式。
    • matches 列表將包含 re.Match 對象(如果匹配成功)或 None(如果匹配失敗)。
  4. 生成獎勵分數

    return [1.0 if match else 0.0 for match in matches]
    
    • 最后一步是根據匹配結果生成獎勵分數。如果匹配成功(即 match 不是 None),則返回 1.0;否則返回 0.0

示例代碼:

completions = [[{"content": "<think>This is reasoning.</think><answer>This is answer.</answer>"}],[{"content": "<think>This is reasoning.</think>"}],[{"content": "<answer>This is answer.</answer>"}],[{"content": "This does not match."}]
]reward_scores = format_reward(completions)
print(reward_scores)  # 輸出: [1.0, 0.0, 0.0, 0.0]

在這個例子中:

  • 第一個完成內容完全匹配正則表達式,因此得分為 1.0
  • 后三個完成內容不符合要求,因此得分均為 0.0

4.4 將數據集格式化為對話形式

# Format into conversationdef make_conversation(example):return {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},{"role": "user", "content": example["problem"]},],}dataset = dataset.map(make_conversation)for split in dataset:if "messages" in dataset[split].column_names:dataset[split] = dataset[split].remove_columns("messages")

將一個數據集中的每個示例轉換為對話格式,并確保數據集中沒有多余的列(如 messages)。

  • 輸入example 是一個字典,包含單個數據樣本的信息,其中 "problem" 鍵對應的值是用戶的問題或任務描述。
  • 輸出:返回一個新的字典,包含一個 "prompt" 鍵,其值是一個對話列表:
    • 第一條消息是系統消息,內容由 SYSTEM_PROMPT 定義。
    • 第二條消息是用戶消息,內容是 example["problem"]
  • dataset.map(make_conversation):使用 map 方法將 make_conversation 函數應用到數據集的每個示例上,生成新的對話格式。
  • 移除多余列:遍歷數據集的每個拆分(split),如果存在 "messages" 列,則將其移除。

4.5 初始化GRPO Trainer

trainer = GRPOTrainer(model=model_args.model_name_or_path,reward_funcs=reward_funcs,args=training_args,train_dataset=dataset[script_args.dataset_train_split],eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,peft_config=get_peft_config(model_args),callbacks=get_callbacks(training_args, model_args),)

篇幅有限,訓練部分的代碼我們放到下一篇博文詳細解讀!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/895813.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/895813.shtml
英文地址,請注明出處:http://en.pswp.cn/news/895813.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【雜談】加油!!!!

為了在三月底前系統準備Java后端開發的面試和筆試&#xff0c;以下是分階段的高效學習計劃&#xff1a; 一、知識體系構建&#xff08;第1-2周&#xff09; 核心基礎強化 Java基礎&#xff08;每日1.5小時&#xff09;&#xff1a; 重點掌握&#xff1a;JVM內存模型&#xff0…

python旅游推薦系統+爬蟲+可視化(協同過濾算法)

??基于用戶的協同過濾算法 ??有后臺管理 ??2w多數據集 這個旅游數據分析推薦系統采用了Python語言、Django框架、MySQL數據庫、requests庫進行網絡爬蟲開發、機器學習中的協同過濾算法、ECharts數據可視化技術&#xff0c;以實現從網站抓取旅游數據、個性化推薦和直觀展…

HarmonyNext上傳用戶相冊圖片到服務器

圖片選擇就不用說了&#xff0c;直接用 無須申請權限 。 上傳圖片&#xff0c;步驟和android對比稍微有點復雜&#xff0c;可能是為了安全性考慮&#xff0c;需要將圖片先拷貝到緩存目錄下面&#xff0c;然后再上傳&#xff0c;當然你也可以轉成Base64&#xff0c;然后和服務…

同為科技智能PDU助力Deepseek人工智能和數據交互的快速發展

1 2025開年&#xff0c;人工智能領域迎來了一場前所未有的變革。Deepseek成為代表“東方力量”的開年王炸&#xff0c;不僅在國內掀起了技術熱潮&#xff0c;并且在全球范圍內引起了高度關注。Deepseek以顛覆性技術突破和現象級應用場景席卷全球&#xff0c;這不僅重塑了產業格…

二、QEMU NFS 環境搭建

? 在上一章節中&#xff0c;我們已經成功完成了內核和 busybox 環境的配置。為了進一步提高開發效率&#xff0c;我們可以使用 NFS&#xff08;Network File System&#xff09;來掛載根目錄。NFS 允許我們將本地文件系統通過網絡共享給虛擬機使用&#xff0c;這樣在開發過程中…

.NET 9.0 的 Blazor Web App 項目中 EF Core 【事務】使用備忘

一、DbContext.Database.BeginTransactionAsync() 模式 1. 注意事項&#xff1a;連接字符串中啟用了 MARS&#xff08;Multiple Active Result Sets&#xff1a;MultipleActiveResultSetsTrue &#xff09;后&#xff0c;無法創建 保存點&#xff08;保存點與 SQL Server 的多…

記一次 Git Fetch 后切換分支為空的情況

Git Fetch 后切換分支為空的情況 在使用 Git 時&#xff0c;我遇到這樣的情況&#xff1a;執行 git fetch 后切換分支&#xff0c;發現工作目錄是空的&#xff0c;沒有任何文件&#xff0c;所以插眼記錄一下。 原因分析 git fetch 的作用&#xff1a;git fetch 只會從遠程倉庫…

UMLS數據下載及訪問

UMLS數據申請 這個直接在官網上申請即可&#xff0c;記得把地址填全&#xff0c;基本都會拿到lisence。 UMLS數據訪問 UMLS的數據訪問分為網頁訪問&#xff0c;API訪問以及數據下載后的本地訪問&#xff0c;網頁訪問&#xff0c;API訪問按照官網的指示即可&#xff0c;這里主…

使用 Docker 部署 Apache Spark 集群教程

簡介 Apache Spark 是一個強大的統一分析引擎&#xff0c;用于大規模數據處理。本文將詳細介紹如何使用 Docker 和 Docker Compose 快速部署一個包含一個 Master 節點和兩個 Worker 節點的 Spark 集群。這種方法不僅簡化了集群的搭建過程&#xff0c;還提供了資源隔離、易于擴…

瑞薩RA-T系列芯片ADCGPT功能模塊的配合使用

在馬達或電源工程中&#xff0c;往往需要采集多路AD信號&#xff0c;且這些信號的優先級和采樣時機不相同。本篇介紹在使用RA-T系列芯片建立馬達或電源工程時&#xff0c;如何根據需求來設置主要功能模塊ADC&GPT&#xff0c;包括采樣通道打包和分組&#xff0c;GPT觸發啟動…

20250217 隨筆 redis非原子性操作簡述

從你提供的文本來看&#xff0c;核心是 Redis 作為緩存的檢查機制&#xff0c;以及非原子性操作導致的不一致性問題。 我們可以拆解為兩個部分來理解&#xff1a; &#x1f4cc; 1. 邏輯&#xff1a;先查 Redis&#xff0c;再決定是否注冊 邏輯流程 先查詢 Redis 是否有某個 …

git-提交時間和作者時間的區別

1.介紹 定義介紹 提交時間&#xff08;Committer Date&#xff09;&#xff1a;決定了提交在 Git 歷史中的位置&#xff0c;通常影響 GitHub 上提交顯示的順序。 作者時間&#xff08;Author Date&#xff09;&#xff1a;雖然不影響提交的排序&#xff0c;但在每個提交詳情頁…

PHP框架入門指南:從零構建現代Web應用

一、為什么需要PHP框架? 1.1 傳統PHP開發的痛點 重復造輪子:用戶認證、表單驗證等基礎功能需要反復開發代碼混亂:缺乏統一結構導致維護困難安全漏洞:手動處理SQL注入/XSS攻擊效率低下擴展性差:耦合代碼難以適應業務增長1.2 框架的核心價值 標準化架構:MVC模式強制代碼分…

Leetcode 146 LRU緩存 的三種解法

146. LRU 緩存 請你設計并實現一個滿足 LRU (最近最少使用) 緩存 約束的數據結構。 實現 LRUCache 類&#xff1a; LRUCache(int capacity) 以 正整數 作為容量 capacity 初始化 LRU 緩存int get(int key) 如果關鍵字 key 存在于緩存中&#xff0c;則返回關鍵字的值&#xff0…

尚硅谷 java 學習Day19 抽象類與抽象方法、接口、內部類

6-5 抽象類(abstract)與抽象方法&#xff08;important&#xff09; 一、什么叫抽象類&#xff1a; 有時候將一個父類設計的非常抽象&#xff0c;以至于它沒有具體的實例&#xff0c;這樣的類稱為抽象類 abstract關鍵字的使用&#xff1a; ? 1、abstract:抽象的 ? 2、abs…

【LeetCode Hot100 鏈表(上)】相交鏈表、反轉鏈表、回文鏈表、環形鏈表、合并兩個有序鏈表、兩數相加

鏈表 1. 相交鏈表問題描述解決思路代碼實現 2. 反轉鏈表問題描述解決思路代碼實現 3. 回文鏈表問題描述解決思路代碼實現 4. 環形鏈表問題描述解決思路代碼實現 5. 環形鏈表II問題描述解決思路代碼實現 6. 合并兩個有序鏈表問題描述解決思路代碼實現 7. 兩數相加問題描述解決思…

【Python pro】基本數據類型

一、數字類型 1.1 數字類型的組成 1.1.1 整數 &#xff08;1&#xff09;十進制&#xff0c;二進制0b&#xff0c;八進制0o&#xff0c;十六進制0x print(16 0b10000 0o20 0x10) # 輸出&#xff1a;True&#xff08;2&#xff09;十進制轉其他進制 a bin(16) b oct(1…

拯救者電腦在重裝系統之后電源計劃丟失Fn+Q切換不了模式怎么恢復?

參考聯想知識庫的一下鏈接&#xff1a; https://iknow.lenovo.com.cn/detail/196192 其中下載的解壓文件后的文件需要復制粘貼到D盤的根目錄下&#xff0c;再來運行文件。若在生成的log文件中看到導入成功以及控制面板中看到已添加的電源計劃即可 如果還是無效可是試試以下的…

ubuntu 執行 sudo apt-get update 報錯

記錄一下&#xff0c;遇到這個問題了&#xff0c;網絡上看到的解決辦法&#xff0c;親測有效 執行sudo apt-get update ,卻報以下錯誤&#xff0c;“SECURITY: URL redirect target contains control characters rejecting ” 經檢查發現&#xff0c;/etc/apt/source.list 下的…

深度集成DeepSeek大模型:WebSocket流式聊天實現

目錄 5分鐘快速接入DeepSeek大模型&#xff1a;WebSocket實時聊天指南創建應用開發后端代碼 (Python/Node.js)結語 5分鐘快速接入DeepSeek大模型&#xff1a;WebSocket實時聊天指南 創建應用 訪問DeepSeek官網 前往 DeepSeek官網。如果還沒有賬號&#xff0c;需要先注冊一個。…