📌 本文介紹如何在 RLHF(Reinforcement Learning with Human Feedback)中使用 PPO(Proximal Policy Optimization)算法對語言模型進行強化學習微調。
🔗 官方文檔:trl PPOTrainer
一、引言:PPO 在 RLHF 中的角色
????????PPO(Proximal Policy Optimization)是一種常用的強化學習優化算法,它在 RLHF 的第三階段發揮核心作用:通過人類偏好訓練出的獎勵模型對語言模型行為進行優化。我們將在本篇中詳細介紹如何基于 Hugging Face 的 trl 庫,結合 IMDb 數據集、情感分析獎勵模型,完成完整的 PPO 訓練流程。
二、環境依賴
pip install peft trl accelerate datasets transformers
三、配置 PPOConfig
from trl import PPOConfigppo_config = PPOConfig(model_name="lvwerra/gpt2-imdb",query_dataset="imdb",reward_model="sentiment-analysis:lvwerra/distilbert-imdb",learning_rate=1.41e-5,log_with=None,mini_batch_size=128,batch_size=128,target_kl=6.0,kl_penalty="kl",seed=0,
)
四、構建數據集與 Tokenizer
from datasets import load_dataset
from transformers import AutoTokenizer
from trl.core import LengthSamplerdef build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8):tokenizer = AutoTokenizer.from_pretrained(config.model_name, use_fast=True)tokenizer.pad_token = tokenizer.eos_tokends = load_dataset(query_dataset, split="train")ds = ds.rename_columns({"text": "review"})ds = ds.filter(lambda x: len(x["review"]) > 200)input_size = LengthSampler(input_min_text_length, input_max_text_length)def tokenize(sample):sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]sample["query"] = tokenizer.decode(sample["input_ids"])return sampleds = ds.map(tokenize)ds.set_format(type="torch")return dsdataset = build_dataset(ppo_config, ppo_config.query_dataset)
五、加載模型與參考模型(Ref Model)
from trl import AutoModelForCausalLMWithValueHeadmodel_cls = AutoModelForCausalLMWithValueHead
model = model_cls.from_pretrained(ppo_config.model_name)
ref_model = model_cls.from_pretrained(ppo_config.model_name)tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
六、構建 PPOTrainer 與獎勵模型
from trl import PPOTrainer
from transformers import pipelinedef collator(data):return dict((key, [d[key] for d in data]) for key in data[0])ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
構建情感獎勵模型
task, model_name = ppo_config.reward_model.split(":")
sentiment_pipe = pipeline(task, model=model_name, device=1 if torch.cuda.is_available() else "cpu", return_all_scores=True, function_to_apply="none", batch_size=16
)# 確保 tokenizer 設置 pad_token_id
sentiment_pipe.tokenizer.pad_token_id = tokenizer.pad_token_id
sentiment_pipe.model.config.pad_token_id = tokenizer.pad_token_id
七、執行 PPO 訓練循環
from tqdm.auto import tqdm
import torchgeneration_kwargs = {"min_length": -1,"top_k": 0.0,"top_p": 1.0,"do_sample": True,"pad_token_id": tokenizer.eos_token_id,"max_new_tokens": 32,
}for step, batch in enumerate(tqdm(ppo_trainer.dataloader)):query_tensors = batch["input_ids"]response_tensors, ref_response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs)batch["response"] = tokenizer.batch_decode(response_tensors)batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)texts = [q + r for q, r in zip(batch["query"], batch["response"])]rewards = [torch.tensor(output[1]["score"]) for output in sentiment_pipe(texts)]ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]ref_rewards = [torch.tensor(output[1]["score"]) for output in sentiment_pipe(ref_texts)]batch["ref_rewards"] = ref_rewardsstats = ppo_trainer.step(query_tensors, response_tensors, rewards)ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"])
八、總結與展望
????????在本篇文章中,我們實現了以下核心步驟:
階段 | 描述 |
---|---|
數據構建 | 利用 IMDb 構造簡短語料用于語言生成 |
模型構建 | 加載 GPT2 并構建 Value Head 以評估獎勵 |
獎勵模型 | 使用 DistilBERT 進行情感打分作為獎勵信號 |
PPO 訓練 | 利用 TRL 中的 PPOTrainer 實現語言強化優化 |
????????PPO 是 RLHF 中至關重要的一環,在人類反饋基礎上不斷微調模型的輸出質量,是當前 ChatGPT、Claude 等大模型背后的關鍵技術之一。
????????📘 下一篇預告:《基于 Python 的自然語言處理系列(86):DPO(Direct Preference Optimization)原理與實戰》
????????相比傳統 RLHF 流程,DPO 提供了一種更簡潔、無需獎勵模型與 PPO 的替代方案,敬請期待!
如果你覺得這篇博文對你有幫助,請點贊、收藏、關注我,并且可以打賞支持我!
歡迎關注我的后續博文,我將分享更多關于人工智能、自然語言處理和計算機視覺的精彩內容。
謝謝大家的支持!