目錄
- 4 GRPO源碼分析
- 4.1 數據類 `GRPOScriptArguments`
- 4.2 系統提示字符串 `SYSTEM_PROMPT`
- 4.3 獎勵函數
- 4.3.1 accuracy_reward函數
- 4.3.2 verify函數
- 4.3.3 format_reward函數
- 4.4 將數據集格式化為對話形式
- 4.5 初始化GRPO Trainer
【復現DeepSeek-R1之Open R1實戰】系列3:SFT和GRPO源碼逐行深度解析(上)
【復現DeepSeek-R1之Open R1實戰】系列5:SFT和GRPO源碼逐行深度解析(中)
4 GRPO源碼分析
前面兩篇博文已經詳細介紹了一些基礎知識和SFT源碼,本文繼續解讀GRPO源碼。與SFT源碼差不多的部分,我們就不展開細說了,這里只解析GRPO獨特的部分。
4.1 數據類 GRPOScriptArguments
該類使用了 Python 的 dataclass
裝飾器,這是一種簡化類定義的方式,特別是對于那些主要用來存儲數據的類。它繼承自 ScriptArguments
類。
-
reward_funcs: 這是一個列表,包含了一系列可能的獎勵函數名稱,默認值為
["accuracy", "format"]
。這些獎勵函數可能是用于評估模型性能的不同標準。reward_funcs: list[str] = field(default_factory=lambda: ["accuracy", "format"],metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"}, )
-
cosine_min_value_wrong 和 cosine_max_value_wrong: 分別表示錯誤答案在余弦相似度尺度上的最小和最大獎勵值,默認分別為
0.0
和-0.5
。 -
cosine_min_value_correct 和 cosine_max_value_correct: 分別表示正確答案在余弦相似度尺度上的最小和最大獎勵值,默認分別為
0.5
和1.0
。 -
cosine_max_len: 表示余弦相似度尺度的最大長度,默認值為
1000
。 -
repetition_n_grams: 表示用于重復懲罰獎勵的n-gram數量,默認值為
3
。 -
repetition_max_penalty: 表示重復懲罰獎勵的最大負值,默認值為
-1.0
。
每個字段都使用了 field()
函數來定義其默認值和元數據(如幫助信息)。這有助于工具和庫更好地理解和處理這些字段,例如生成命令行解析器時。
4.2 系統提示字符串 SYSTEM_PROMPT
SYSTEM_PROMPT = ("A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant ""first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning ""process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., ""<think> reasoning process here </think><answer> answer here </answer>"
)
字符串描述了一個對話場景,用戶先提問,助手首先思考推理過程,然后提供答案。推理過程和答案分別用 <think>
和 <answer>
標簽包裹,這種格式化有助于區分和識別不同的部分,和DeepSeek-R1的思考過程格式一致。
4.3 獎勵函數
獎勵函數的定義如下,GRPO默認用到了accuracy_reward和format_reward這兩個函數。
# Get reward functionsREWARD_FUNCS_REGISTRY = {"accuracy": accuracy_reward,"format": format_reward,"reasoning_steps": reasoning_steps_reward,"cosine": get_cosine_scaled_reward(min_value_wrong=script_args.cosine_min_value_wrong,max_value_wrong=script_args.cosine_max_value_wrong,min_value_correct=script_args.cosine_min_value_correct,max_value_correct=script_args.cosine_max_value_correct,max_len=script_args.cosine_max_len,),"repetition_penalty": get_repetition_penalty_reward(ngram_size=script_args.repetition_n_grams,max_penalty=script_args.repetition_max_penalty,),"length": len_reward,}reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
這段代碼定義了一個獎勵函數注冊表 REWARD_FUNCS_REGISTRY
,并根據用戶提供的配置動態生成一個獎勵函數列表 reward_funcs
。每個獎勵函數用于評估模型輸出的不同方面,如準確性、格式、推理步驟等。
- 注冊表定義
accuracy
: 使用accuracy_reward
函數評估模型輸出的準確性。format
: 使用format_reward
函數評估模型輸出的格式。reasoning_steps
: 使用reasoning_steps_reward
函數評估模型輸出的推理步驟。cosine
: 使用get_cosine_scaled_reward
函數計算余弦相似度獎勵,參數包括:min_value_wrong
: 錯誤情況下的最小值。max_value_wrong
: 錯誤情況下的最大值。min_value_correct
: 正確情況下的最小值。max_value_correct
: 正確情況下的最大值。max_len
: 最大長度。
repetition_penalty
: 使用get_repetition_penalty_reward
函數計算重復懲罰獎勵,參數包括:ngram_size
: n-gram 的大小。max_penalty
: 最大懲罰值。
length
: 使用len_reward
函數評估模型輸出的長度。
- 動態生成獎勵函數列表
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
- 根據
script_args.reward_funcs
中指定的獎勵函數名稱,從REWARD_FUNCS_REGISTRY
中獲取相應的獎勵函數,并生成一個列表reward_funcs
。
4.3.1 accuracy_reward函數
該函數用于計算模型生成的補全與真實答案之間的準確性獎勵。它通過解析和驗證生成的內容與真實答案來確定獎勵值。
def accuracy_reward(completions, solution, **kwargs):"""Reward function that checks if the completion is the same as the ground truth."""contents = [completion[0]["content"] for completion in completions]rewards = []for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)if len(gold_parsed) != 0:# We require the answer to be provided in correct latex (no malformed operators)answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)# Reward 1 if the content is the same as the ground truth, 0 otherwisereward = float(verify(answer_parsed, gold_parsed))else:# If the gold solution is not parseable, we reward 1 to skip this examplereward = 1.0print("Failed to parse gold solution: ", sol)rewards.append(reward)return rewards
- completions (
list
): 包含多個補全結果的列表,每個補全結果是一個包含內容的字典列表。 - solution (
list
): 真實答案的列表。 - kwargs: 其他可選參數(在本函數中未使用)。
-
提取補全內容
contents = [completion[0]["content"] for completion in completions]
- 從
completions
列表中提取每個補全的第一個內容(假設每個補全是單個元素的列表),形成一個新的contents
列表。
- 從
-
初始化獎勵列表
rewards = []
-
遍歷每個補全和對應的真實答案
for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)
- 使用
zip
函數將contents
和solution
配對。 - 對于每一對補全內容和真實答案,首先解析真實答案
sol
,使用parse
函數提取其中的信息。
- 使用
-
處理解析結果
if len(gold_parsed) != 0:answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)
- 如果解析得到的真實答案
gold_parsed
非空,則繼續解析生成的補全內容content
。 - 使用
LatexExtractionConfig
和NormalizationConfig
進行詳細配置,確保解析過程中考慮了各種格式要求(如方程、單位等)。
- 如果解析得到的真實答案
-
計算獎勵
reward = float(verify(answer_parsed, gold_parsed))
- 使用
verify
函數比較生成的補全解析結果和真實答案的解析結果。 - 如果兩者匹配,則返回
1.0
,否則返回0.0
。
- 使用
-
處理無法解析的情況
else:reward = 1.0print("Failed to parse gold solution: ", sol)
- 如果真實答案無法解析,則默認給予獎勵
1.0
并打印警告信息。
- 如果真實答案無法解析,則默認給予獎勵
-
添加獎勵到列表
rewards.append(reward)
-
返回所有獎勵
return rewards
4.3.2 verify函數
該函數用于驗證目標表達式是否與參考表達式匹配,它通過多種比較策略來處理不同的數學對象(如數字、表達式、集合、矩陣等),并提供靈活的配置選項以適應不同的需求。
def verify(gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, float_rounding: int=6,numeric_precision: int=15,strict: bool=True,timeout_seconds: int=3
) -> bool:
- gold: 參考或正確的表達式,可以是單個 SymPy 表達式(
Basic
或MatrixBase
)、字符串或這些類型的列表。 - target: 需要驗證的表達式,類型同
gold
。 - float_rounding: 浮點數舍入的小數位數,默認為 6。
- numeric_precision: 數值比較時考慮的小數位數,默認為 15。
- strict: 是否啟用嚴格比較模式,默認為
True
。- 在嚴格模式下:變量很重要,集合不可與元組比較。
- 在非嚴格模式下:變量按位置匹配,集合可與元組比較。
- timeout_seconds: 單次比較操作的最大超時時間(秒),默認為 3 秒。
-
定義內部比較函數
compare_single_extraction
@timeout(timeout_seconds=timeout_seconds) def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool:...
- 使用裝飾器
@timeout
設置超時保護,默認超時時間為timeout_seconds
。 - 比較兩個表達式:
- 如果兩者都是 SymPy 表達式(
Basic
或MatrixBase
),則調用sympy_expr_eq
進行比較。 - 如果兩者都是字符串,則進行簡單的字符串比較。
- 如果兩者都是 SymPy 表達式(
- 使用裝飾器
-
定義包裝函數
compare_single_extraction_wrapper
def compare_single_extraction_wrapper(g, t):try:return compare_single_extraction(g, t)except Exception as e:logger.exception(f"Error comparing {g} and {t}")return False
- 包裝
compare_single_extraction
,捕獲并記錄任何異常,返回False
以避免程序中斷。
- 包裝
-
處理輸入列表
if not isinstance(gold, list):gold = [gold] if not isinstance(target, list):target = [target]
- 如果
gold
或target
不是列表,則將其轉換為單元素列表,以便統一處理。
- 如果
-
組合所有可能的比較
return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target))
- 使用
itertools.product
生成所有可能的gold
和target
組合。 - 對每個組合調用
compare_single_extraction_wrapper
,如果任意一對匹配成功,則返回True
。
- 使用
4.3.3 format_reward函數
函數用于檢查給定的完成文本是否符合特定的格式,它驗證完成文本是否包含 <think>
和 <answer>
標簽,并且這兩個標簽的內容是非空的。
def format_reward(completions, **kwargs):"""Reward function that checks if the completion has a specific format."""pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"completion_contents = [completion[0]["content"] for completion in completions]matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]return [1.0 if match else 0.0 for match in matches]
- completions: 這是一個列表,其中每個元素都是一個包含完成內容的對象(通常是字典)。假設每個完成對象的第一個元素包含一個鍵
"content"
,其值是需要檢查的文本。 - kwargs: 其他關鍵字參數,這里沒有使用,但可以為未來的擴展提供靈活性。
-
正則表達式模式定義:
pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
- 這個正則表達式用于匹配字符串是否以
<think>
開始,緊接著是任意字符(非貪婪匹配),然后是</think>
,接著可能有任意數量的空白字符(包括換行符),最后是以<answer>
開始并以</answer>
結束。 .*?
是非貪婪匹配,確保盡可能少地匹配字符。\s*
匹配零個或多個空白字符(包括換行符)。re.DOTALL | re.MULTILINE
標志允許點號.
匹配所有字符(包括換行符),并且使多行文本中的每一行都可以獨立匹配。
- 這個正則表達式用于匹配字符串是否以
-
提取完成內容:
completion_contents = [completion[0]["content"] for completion in completions]
- 這里通過列表推導式從
completions
列表中提取每個完成對象的第一個元素的"content"
字段,形成一個新的列表completion_contents
。
- 這里通過列表推導式從
-
匹配正則表達式:
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
- 使用
re.match
函數對completion_contents
中的每個內容應用正則表達式模式。 matches
列表將包含re.Match
對象(如果匹配成功)或None
(如果匹配失敗)。
- 使用
-
生成獎勵分數:
return [1.0 if match else 0.0 for match in matches]
- 最后一步是根據匹配結果生成獎勵分數。如果匹配成功(即
match
不是None
),則返回1.0
;否則返回0.0
。
- 最后一步是根據匹配結果生成獎勵分數。如果匹配成功(即
示例代碼:
completions = [[{"content": "<think>This is reasoning.</think><answer>This is answer.</answer>"}],[{"content": "<think>This is reasoning.</think>"}],[{"content": "<answer>This is answer.</answer>"}],[{"content": "This does not match."}]
]reward_scores = format_reward(completions)
print(reward_scores) # 輸出: [1.0, 0.0, 0.0, 0.0]
在這個例子中:
- 第一個完成內容完全匹配正則表達式,因此得分為
1.0
。 - 后三個完成內容不符合要求,因此得分均為
0.0
。
4.4 將數據集格式化為對話形式
# Format into conversationdef make_conversation(example):return {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},{"role": "user", "content": example["problem"]},],}dataset = dataset.map(make_conversation)for split in dataset:if "messages" in dataset[split].column_names:dataset[split] = dataset[split].remove_columns("messages")
將一個數據集中的每個示例轉換為對話格式,并確保數據集中沒有多余的列(如 messages
)。
- 輸入:
example
是一個字典,包含單個數據樣本的信息,其中"problem"
鍵對應的值是用戶的問題或任務描述。 - 輸出:返回一個新的字典,包含一個
"prompt"
鍵,其值是一個對話列表:- 第一條消息是系統消息,內容由
SYSTEM_PROMPT
定義。 - 第二條消息是用戶消息,內容是
example["problem"]
。
- 第一條消息是系統消息,內容由
dataset.map(make_conversation)
:使用map
方法將make_conversation
函數應用到數據集的每個示例上,生成新的對話格式。- 移除多余列:遍歷數據集的每個拆分(split),如果存在
"messages"
列,則將其移除。
4.5 初始化GRPO Trainer
trainer = GRPOTrainer(model=model_args.model_name_or_path,reward_funcs=reward_funcs,args=training_args,train_dataset=dataset[script_args.dataset_train_split],eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,peft_config=get_peft_config(model_args),callbacks=get_callbacks(training_args, model_args),)
篇幅有限,訓練部分的代碼我們放到下一篇博文詳細解讀!