VLM(視覺語言模型)與DeepSeek R1(獎勵機制)如何結合
flyfish
VLM的傳統訓練依賴于監督學習(直接擬合問答對),而規則獎勵函數通常用于強化學習(通過試錯和獎勵反饋優化策略)。這兩種方式如何結合?
源碼來自
VLM-R1/src/open-r1-multimodal/src/open_r1/grpo_rec.py
# 導入 debugpy 庫,用于調試,當前代碼中被注釋掉,若需要調試可取消注釋
# import debugpy
# try:
# # 5678 是 VS Code 調試配置中的默認附加端口。除非指定主機和端口,否則主機默認為 127.0.0.1
# debugpy.listen(("localhost", 9501))
# print("Waiting for debugger attach")
# debugpy.wait_for_client()
# except Exception as e:
# pass# 導入操作系統相關功能的庫
import os
# 導入正則表達式庫,用于字符串匹配和處理
import re
# 導入日期時間處理庫
from datetime import datetime
# 導入數據類裝飾器和字段定義類,用于定義數據類
from dataclasses import dataclass, field
# 導入可選類型注解,用于表示某個參數可以為 None
from typing import Optional# 導入 Pillow 庫中的 Image 類,用于處理圖像
from PIL import Image
# 導入 PyTorch 中的數據集基類
from torch.utils.data import Dataset
# 導入 Qwen2VL 條件生成模型
from transformers import Qwen2VLForConditionalGeneration# 導入自定義的數學驗證模塊中的解析和驗證函數
from math_verify import parse, verify
# 導入自定義的 Qwen2VLGRPOTrainer 類
from open_r1.trainer import Qwen2VLGRPOTrainer
# 導入 TRL 庫中的 GRPO 配置、訓練器、模型配置、腳本參數、解析器和 PEFT 配置獲取函數
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
# 導入 Transformers 庫中的訓練參數類
from transformers import TrainingArguments
# 導入 YAML 文件處理庫
import yaml
# 導入 JSON 文件處理庫
import json
# 導入隨機數生成庫
import random
# 導入數學計算庫
import math# ----------------------- 修復當前版本 transformers 中的 flash attention 錯誤 -----------------------
# 導入 Qwen2_5_VL 模型中的相關類和函數
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
# 導入 PyTorch 庫
import torch
# 導入元組類型注解
from typing import Tuple# 自定義 Qwen2_5_VLVisionFlashAttention2 類的前向傳播函數
def custom_forward(self,hidden_states: torch.Tensor,cu_seqlens: torch.Tensor,rotary_pos_emb: Optional[torch.Tensor] = None,position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,) -> torch.Tensor:# 獲取隱藏狀態的序列長度seq_length = hidden_states.shape[0]# 通過 qkv 層得到查詢、鍵、值張量,并進行形狀調整和維度置換q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)# 如果沒有提供位置嵌入,則根據旋轉位置嵌入計算余弦和正弦值if position_embeddings is None:# 打印一次警告信息,提示 RoPE 嵌入計算方式的變化logger.warning_once("The attention layers in this model are transitioning from computing the RoPE embeddings internally ""through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed ""`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be ""removed and `position_embeddings` will be mandatory.")# 拼接旋轉位置嵌入emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)# 計算余弦值cos = emb.cos().float()# 計算正弦值sin = emb.sin().float()else:# 從位置嵌入中獲取余弦和正弦值cos, sin = position_embeddings# 將余弦值轉換為浮點類型cos = cos.to(torch.float)# 將正弦值轉換為浮點類型sin = sin.to(torch.float)# 應用旋轉位置嵌入到查詢和鍵張量q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)# 去除查詢張量的額外維度q = q.squeeze(0)# 去除鍵張量的額外維度k = k.squeeze(0)# 計算最大序列長度max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()# 調用 flash 注意力函數計算注意力輸出attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(seq_length, -1)# 通過投影層得到最終的注意力輸出attn_output = self.proj(attn_output)return attn_output# 將自定義的前向傳播函數賦值給 Qwen2_5_VLVisionFlashAttention2 類的 forward 方法
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward# ----------------------- 主腳本 -----------------------
# 定義 GRPOScriptArguments 數據類,繼承自 ScriptArguments
@dataclass
class GRPOScriptArguments(ScriptArguments):"""用于 GRPO 訓練腳本的腳本參數。參數:reward_funcs (`list[str]`):獎勵函數列表。可能的值: 'accuracy', 'format'。"""# 獎勵函數列表,默認包含 'accuracy' 和 'format'reward_funcs: list[str] = field(default_factory=lambda: ["accuracy", "format"],metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"})# 圖像的最大像素數,默認為 12845056max_pixels: Optional[int] = field(default=12845056,metadata={"help": "Maximum number of pixels for the image"})# 圖像的最小像素數,默認為 3136min_pixels: Optional[int] = field(default=3136,metadata={"help": "Minimum number of pixels for the image"})# 圖像的根目錄,默認為 Noneimage_root: Optional[str] = field(default=None,metadata={"help": "Root directory of the image"})# 定義系統提示信息,用于指導模型的對話生成
SYSTEM_PROMPT = ("A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant ""first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning ""process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., ""<think> reasoning process here </think><answer> answer here </answer>"
)# 定義 LazySupervisedDataset 類,繼承自 Dataset
class LazySupervisedDataset(Dataset):def __init__(self, data_path: str, script_args: GRPOScriptArguments):# 調用父類的構造函數super(LazySupervisedDataset, self).__init__()# 保存腳本參數self.script_args = script_args# 初始化數據字典列表self.list_data_dict = []# 如果數據文件是 YAML 格式if data_path.endswith(".yaml"):# 打開 YAML 文件with open(data_path, "r") as file:# 加載 YAML 數據yaml_data = yaml.safe_load(file)# 獲取數據集列表datasets = yaml_data.get("datasets")# 文件格式應為:# datasets:# - json_path: xxxx1.json# sampling_strategy: first:1000# - json_path: xxxx2.json# sampling_strategy: end:3000# - json_path: xxxx3.json# sampling_strategy: random:999# 遍歷每個數據集for data in datasets:# 獲取 JSON 文件路徑json_path = data.get("json_path")# 獲取采樣策略,默認為 'all'sampling_strategy = data.get("sampling_strategy", "all")# 初始化采樣數量為 Nonesampling_number = None# 如果 JSON 文件是 JSONL 格式if json_path.endswith(".jsonl"):# 初始化當前數據字典列表cur_data_dict = []# 打開 JSONL 文件with open(json_path, "r") as json_file:# 逐行讀取文件for line in json_file:# 解析每行 JSON 數據并添加到當前數據字典列表cur_data_dict.append(json.loads(line.strip()))# 如果 JSON 文件是 JSON 格式elif json_path.endswith(".json"):# 打開 JSON 文件with open(json_path, "r") as json_file:# 加載 JSON 數據到當前數據字典列表cur_data_dict = json.load(json_file)else:# 如果文件類型不支持,拋出異常raise ValueError(f"Unsupported file type: {json_path}")# 如果采樣策略包含冒號if ":" in sampling_strategy:# 分割采樣策略和采樣數量sampling_strategy, sampling_number = sampling_strategy.split(":")# 如果采樣數量包含百分比符號if "%" in sampling_number:# 計算采樣數量sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)else:# 將采樣數量轉換為整數sampling_number = int(sampling_number)# 應用采樣策略if sampling_strategy == "first" and sampling_number is not None:# 取前 sampling_number 個樣本cur_data_dict = cur_data_dict[:sampling_number]elif sampling_strategy == "end" and sampling_number is not None:# 取后 sampling_number 個樣本cur_data_dict = cur_data_dict[-sampling_number:]elif sampling_strategy == "random" and sampling_number is not None:# 隨機打亂樣本random.shuffle(cur_data_dict)# 取前 sampling_number 個樣本cur_data_dict = cur_data_dict[:sampling_number]# 打印從當前 JSON 文件加載的樣本數量print(f"Loaded {len(cur_data_dict)} samples from {json_path}")# 將當前數據字典列表添加到總數據字典列表self.list_data_dict.extend(cur_data_dict)else:# 如果文件類型不支持,拋出異常raise ValueError(f"Unsupported file type: {data_path}")def __len__(self):# 返回數據字典列表的長度return len(self.list_data_dict)def __getitem__(self, i):# 定義將示例轉換為對話格式的函數def make_conversation(example):return {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},{"role": "user", "content": example["problem"]}]}# 問題模板,用于包含圖像的對話QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."# 定義將包含圖像的示例轉換為對話格式的函數def make_conversation_image(example):return {"prompt": [# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},{"role": "user","content": [{"type": "image"},{"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])}]}]}# 獲取指定索引的示例example = self.list_data_dict[i]# 獲取圖像根目錄image_root = self.script_args.image_root# 如果示例中包含圖像信息if 'image' in example:# 構建圖像路徑image_path = os.path.join(image_root, example['image'])# 如果圖像文件不存在while not os.path.exists(image_path):# 打印警告信息print(f"Warning: Image {image_path} not found, randomly selecting another image")# 隨機選擇一個新的索引new_index = random.randint(0, len(self.list_data_dict)-1)# 獲取新的示例example = self.list_data_dict[new_index]# 構建新的圖像路徑image_path = os.path.join(image_root, example['image'])# 打開圖像并轉換為 RGB 格式image = Image.open(image_path).convert("RGB")else:# 如果示例中不包含圖像信息,圖像為 Noneimage = Nonereturn {'image': image,'problem': example['problem'],'solution': example['solution'],'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt']}'''如果模型預測的邊界框與真實邊界框的交并比(IoU)大于 0.5,則獎勵為 1.0,否則為 0.0。這是一種硬獎勵,未來可能使用軟獎勵會更好。
'''
def iou_reward(completions, solution, **kwargs):# 定義計算交并比的函數def iou(box1, box2):# 計算交集的左上角坐標inter_x1 = max(box1[0], box2[0])inter_y1 = max(box1[1], box2[1])# 計算交集的右下角坐標inter_x2 = min(box1[2]-1, box2[2]-1)inter_y2 = min(box1[3]-1, box2[3]-1)# 如果交集存在if inter_x1 < inter_x2 and inter_y1 < inter_y2:# 計算交集面積inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)else:# 交集面積為 0inter = 0# 計算并集面積union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter# 返回交并比return float(inter)/union# 獲取完成內容列表contents = [completion[0]["content"] for completion in completions]# 初始化獎勵列表rewards = []# 獲取當前時間并格式化current_time = datetime.now().strftime("%d-%H-%M-%S-%f")# 定義答案標簽的正則表達式模式answer_tag_pattern = r'<answer>(.*?)</answer>'# 定義邊界框的正則表達式模式bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'# 遍歷完成內容和真實解決方案for content, sol in zip(contents, solution):# 初始化獎勵為 0.0reward = 0.0# 嘗試進行符號驗證try:# 在完成內容中查找答案標簽content_answer_match = re.search(answer_tag_pattern, content)if content_answer_match:# 獲取答案內容content_answer = content_answer_match.group(1).strip()# 在答案內容中查找邊界框bbox_match = re.search(bbox_pattern, content_answer)if bbox_match:# 獲取邊界框坐標bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]# 如果交并比大于 0.5if iou(bbox, sol) > 0.5:# 獎勵為 1.0reward = 1.0except Exception:# 如果驗證失敗,繼續下一個驗證方法pass# 將獎勵添加到獎勵列表rewards.append(reward)# 如果處于調試模式if os.getenv("DEBUG_MODE") == "true":# 獲取日志路徑log_path = os.getenv("LOG_PATH")# 打開日志文件并追加記錄with open(log_path, "a") as f:# 記錄當前時間和獎勵信息f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")# 記錄完成內容f.write(f"Content: {content}\n")# 記錄真實解決方案f.write(f"Solution: {sol}\n")return rewardsdef format_reward(completions, **kwargs):"""獎勵函數,用于檢查完成內容是否符合特定格式。"""# 定義格式的正則表達式模式# pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"# 獲取完成內容列表completion_contents = [completion[0]["content"] for completion in completions]# 檢查每個完成內容是否符合格式matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]# 根據匹配結果生成獎勵列表return [1.0 if match else 0.0 for match in matches]# 獎勵函數注冊表,將獎勵函數名稱映射到對應的函數
reward_funcs_registry = {"accuracy": iou_reward,"format": format_reward,
}def main(script_args, training_args, model_args):# 根據腳本參數中的獎勵函數名稱,從注冊表中獲取對應的獎勵函數reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]# 打印獎勵函數列表print("reward_funcs:", reward_funcs)# 加載數據集dataset = LazySupervisedDataset(script_args.dataset_name, script_args)# 選擇訓練器類,這里使用自定義的 Qwen2VLGRPOTrainertrainer_cls = Qwen2VLGRPOTrainer# 初始化 GRPO 訓練器trainer = trainer_cls(model=model_args.model_name_or_path, # 模型名稱或路徑reward_funcs=reward_funcs, # 獎勵函數列表args=training_args, # 訓練參數train_dataset=dataset, # 訓練數據集eval_dataset=None, # 評估數據集,這里設為 Nonepeft_config=get_peft_config(model_args), # PEFT 配置attn_implementation=model_args.attn_implementation, # 注意力實現方式max_pixels=script_args.max_pixels, # 圖像最大像素數min_pixels=script_args.min_pixels, # 圖像最小像素數torch_dtype=model_args.torch_dtype, # PyTorch 數據類型)# 開始訓練模型trainer.train()# 保存模型到指定的輸出目錄trainer.save_model(training_args.output_dir)# 如果設置了將模型推送到 Hubif training_args.push_to_hub:# 將模型推送到 Hub,并指定數據集名稱trainer.push_to_hub(dataset_name=script_args.dataset_name)if __name__ == "__main__":# 創建 TrlParser 對象,用于解析腳本參數、訓練配置和模型配置parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))# 解析命令行參數和配置script_args, training_args, model_args = parser.parse_args_and_config()# 調用主函數開始訓練main(script_args, training_args, model_args)
代碼中的兩個關鍵獎勵函數 format_reward
和 iou_reward
。
1. 格式獎勵函數 format_reward
函數定義和功能
def format_reward(completions, **kwargs):"""Reward function that checks if the completion has a specific format."""pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"completion_contents = [completion[0]["content"] for completion in completions]matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]return [1.0 if match else 0.0 for match in matches]
此函數的主要功能是檢查模型生成的完成內容是否符合特定的格式要求。具體來說,它期望模型的輸出滿足以下格式:
- 包含
<think>
和</think>
標簽,用于包裹思考過程。 - 包含
<answer>
和</answer>
標簽,用于包裹答案。 - 答案部分需要是一個 JSON 格式,并且其中包含一個由四個整數組成的列表,通常可以理解為表示邊界框的坐標。
實現步驟
- 定義正則表達式模式:
pattern
是一個正則表達式,用于描述期望的輸出格式。 - 提取完成內容:
completion_contents
從completions
中提取出每個完成內容的文本部分。 - 檢查格式匹配:
matches
使用re.fullmatch
函數檢查每個完成內容是否完全匹配正則表達式模式。 - 生成獎勵列表:根據匹配結果,為每個完成內容生成一個獎勵值,如果匹配則為 1.0,否則為 0.0。
作用
通過這個獎勵函數,模型在訓練過程中會被激勵去生成符合特定格式的輸出,有助于規范模型的回答結構,使得輸出更易于解析和使用。
2. 交并比(IoU)獎勵函數 iou_reward
函數定義和功能
def iou_reward(completions, solution, **kwargs):def iou(box1, box2):inter_x1 = max(box1[0], box2[0])inter_y1 = max(box1[1], box2[1])inter_x2 = min(box1[2]-1, box2[2]-1)inter_y2 = min(box1[3]-1, box2[3]-1)if inter_x1 < inter_x2 and inter_y1 < inter_y2:inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)else:inter = 0union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - interreturn float(inter)/unioncontents = [completion[0]["content"] for completion in completions]rewards = []current_time = datetime.now().strftime("%d-%H-%M-%S-%f")answer_tag_pattern = r'<answer>(.*?)</answer>'bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'for content, sol in zip(contents, solution):reward = 0.0try:content_answer_match = re.search(answer_tag_pattern, content)if content_answer_match:content_answer = content_answer_match.group(1).strip()bbox_match = re.search(bbox_pattern, content_answer)if bbox_match:bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]if iou(bbox, sol) > 0.5:reward = 1.0except Exception:passrewards.append(reward)if os.getenv("DEBUG_MODE") == "true":log_path = os.getenv("LOG_PATH")with open(log_path, "a") as f:f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")f.write(f"Content: {content}\n")f.write(f"Solution: {sol}\n")return rewards
此函數的主要功能是評估模型預測的邊界框與真實邊界框之間的重疊程度,并根據交并比(IoU)值給予獎勵。
實現步驟
- 定義 IoU 計算函數:
iou
函數用于計算兩個邊界框的交并比。它首先計算兩個邊界框的交集面積和并集面積,然后將交集面積除以并集面積得到 IoU 值。 - 提取完成內容:
contents
從completions
中提取出每個完成內容的文本部分。 - 查找答案和邊界框:使用正則表達式
answer_tag_pattern
查找完成內容中的答案部分,再使用bbox_pattern
查找答案中的邊界框坐標。 - 計算 IoU 并給予獎勵:對于每個完成內容,提取預測的邊界框坐標,與真實邊界框計算 IoU 值。如果 IoU 值大于 0.5,則給予 1.0 的獎勵,否則給予 0.0 的獎勵。
- 日志記錄(可選):如果設置了調試模式(
DEBUG_MODE
為true
),則將每個完成內容的獎勵信息記錄到日志文件中。
作用
通過這個獎勵函數,模型在訓練過程中會被激勵去預測更準確的邊界框,提高目標檢測的精度。同時,結合格式獎勵函數,可以讓模型不僅準確預測邊界框,還能以規定的格式輸出結果。
監督學習與規則獎勵函數強化學習的結合方式
1. 數據層面的結合
- 利用監督數據初始化模型:在開始強化學習訓練之前,使用監督學習的方式對視覺語言模型(VLM)進行預訓練。通過直接擬合問答對數據,讓模型學習到基本的語言和視覺特征表示以及問題回答的模式。例如,在代碼中使用
LazySupervisedDataset
類加載數據集,這些數據可以作為監督學習階段的訓練數據,讓模型初步學習到如何根據問題和圖像生成答案。 - 監督數據作為強化學習的參考:在強化學習的過程中,監督學習的數據可以作為參考來評估模型的輸出。例如,在
iou_reward
函數中,通過比較模型預測的邊界框與真實邊界框的交并比(IoU)來給予獎勵,這里的真實邊界框就是監督學習中的標簽信息。
2. 訓練過程的結合
- 分階段訓練:先進行監督學習訓練,讓模型收斂到一個較好的初始狀態。然后再切換到強化學習階段,使用規則獎勵函數來進一步優化模型的策略。在代碼中,雖然沒有明確體現分階段訓練的邏輯,但可以在實際應用中先使用監督學習的方法對
Qwen2VLForConditionalGeneration
模型進行訓練,然后再使用Qwen2VLGRPOTrainer
進行強化學習訓練。 - 混合訓練:在每個訓練步驟中,既使用監督學習的損失函數,又使用強化學習的獎勵函數。例如,可以將監督學習的交叉熵損失和強化學習的獎勵損失加權求和,作為總的損失函數來更新模型參數。這樣可以讓模型在學習過程中既考慮到直接擬合標簽的準確性,又考慮到長期的獎勵優化。
3. 獎勵函數設計結合監督信息
- 準確性獎勵:如
iou_reward
函數,將模型輸出與監督學習中的標簽進行比較,根據比較結果給予獎勵。這種獎勵函數可以促使模型在強化學習過程中輸出更接近真實標簽的結果,從而結合了監督學習的信息。 - 格式獎勵:
format_reward
函數可以確保模型輸出的格式符合特定要求,這可以看作是一種規則約束。同時,這種格式要求也可以是在監督學習階段就定義好的,從而將監督學習中的格式規范融入到強化學習的獎勵機制中。