基于 Python 的自然語言處理系列(85):PPO 原理與實踐

📌 本文介紹如何在 RLHF(Reinforcement Learning with Human Feedback)中使用 PPO(Proximal Policy Optimization)算法對語言模型進行強化學習微調。

🔗 官方文檔:trl PPOTrainer

一、引言:PPO 在 RLHF 中的角色

????????PPO(Proximal Policy Optimization)是一種常用的強化學習優化算法,它在 RLHF 的第三階段發揮核心作用:通過人類偏好訓練出的獎勵模型對語言模型行為進行優化。我們將在本篇中詳細介紹如何基于 Hugging Face 的 trl 庫,結合 IMDb 數據集、情感分析獎勵模型,完成完整的 PPO 訓練流程。

二、環境依賴

pip install peft trl accelerate datasets transformers

三、配置 PPOConfig

from trl import PPOConfigppo_config = PPOConfig(model_name="lvwerra/gpt2-imdb",query_dataset="imdb",reward_model="sentiment-analysis:lvwerra/distilbert-imdb",learning_rate=1.41e-5,log_with=None,mini_batch_size=128,batch_size=128,target_kl=6.0,kl_penalty="kl",seed=0,
)

四、構建數據集與 Tokenizer

from datasets import load_dataset
from transformers import AutoTokenizer
from trl.core import LengthSamplerdef build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8):tokenizer = AutoTokenizer.from_pretrained(config.model_name, use_fast=True)tokenizer.pad_token = tokenizer.eos_tokends = load_dataset(query_dataset, split="train")ds = ds.rename_columns({"text": "review"})ds = ds.filter(lambda x: len(x["review"]) > 200)input_size = LengthSampler(input_min_text_length, input_max_text_length)def tokenize(sample):sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]sample["query"] = tokenizer.decode(sample["input_ids"])return sampleds = ds.map(tokenize)ds.set_format(type="torch")return dsdataset = build_dataset(ppo_config, ppo_config.query_dataset)

五、加載模型與參考模型(Ref Model)

from trl import AutoModelForCausalLMWithValueHeadmodel_cls = AutoModelForCausalLMWithValueHead
model = model_cls.from_pretrained(ppo_config.model_name)
ref_model = model_cls.from_pretrained(ppo_config.model_name)tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id

六、構建 PPOTrainer 與獎勵模型

from trl import PPOTrainer
from transformers import pipelinedef collator(data):return dict((key, [d[key] for d in data]) for key in data[0])ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

構建情感獎勵模型

task, model_name = ppo_config.reward_model.split(":")
sentiment_pipe = pipeline(task, model=model_name, device=1 if torch.cuda.is_available() else "cpu", return_all_scores=True, function_to_apply="none", batch_size=16
)# 確保 tokenizer 設置 pad_token_id
sentiment_pipe.tokenizer.pad_token_id = tokenizer.pad_token_id
sentiment_pipe.model.config.pad_token_id = tokenizer.pad_token_id

七、執行 PPO 訓練循環

 
from tqdm.auto import tqdm
import torchgeneration_kwargs = {"min_length": -1,"top_k": 0.0,"top_p": 1.0,"do_sample": True,"pad_token_id": tokenizer.eos_token_id,"max_new_tokens": 32,
}for step, batch in enumerate(tqdm(ppo_trainer.dataloader)):query_tensors = batch["input_ids"]response_tensors, ref_response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs)batch["response"] = tokenizer.batch_decode(response_tensors)batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)texts = [q + r for q, r in zip(batch["query"], batch["response"])]rewards = [torch.tensor(output[1]["score"]) for output in sentiment_pipe(texts)]ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]ref_rewards = [torch.tensor(output[1]["score"]) for output in sentiment_pipe(ref_texts)]batch["ref_rewards"] = ref_rewardsstats = ppo_trainer.step(query_tensors, response_tensors, rewards)ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"])

八、總結與展望

????????在本篇文章中,我們實現了以下核心步驟:

階段描述
數據構建利用 IMDb 構造簡短語料用于語言生成
模型構建加載 GPT2 并構建 Value Head 以評估獎勵
獎勵模型使用 DistilBERT 進行情感打分作為獎勵信號
PPO 訓練利用 TRL 中的 PPOTrainer 實現語言強化優化

????????PPO 是 RLHF 中至關重要的一環,在人類反饋基礎上不斷微調模型的輸出質量,是當前 ChatGPT、Claude 等大模型背后的關鍵技術之一。

????????📘 下一篇預告:《基于 Python 的自然語言處理系列(86):DPO(Direct Preference Optimization)原理與實戰》
????????相比傳統 RLHF 流程,DPO 提供了一種更簡潔、無需獎勵模型與 PPO 的替代方案,敬請期待!

如果你覺得這篇博文對你有幫助,請點贊、收藏、關注我,并且可以打賞支持我!

歡迎關注我的后續博文,我將分享更多關于人工智能、自然語言處理和計算機視覺的精彩內容。

謝謝大家的支持!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/78308.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/78308.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/78308.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

珍愛網:從降本增效到綠色低碳,數字化新基建價值凸顯

2024年12月24日,法大大聯合企業綠色發展研究院發布《2024簽約減碳與低碳辦公白皮書》,深入剖析電子簽在推動企業綠色低碳轉型中的關鍵作用,為企業實現環境、社會和治理(ESG)目標提供新思路。近期,法大大將陸…

Java實現HTML轉PDF(deepSeekAi->html->pdf)

Java實現HTML轉PDF,主要為了解決將ai返回的html文本數據轉為PDF文件方便用戶下載查看。 一、deepSeek-AI提問詞 基于以上個人數據。總結個人身體信息,分析個人身體指標信息。再按一個月為維度,詳細列舉一個月內訓練計劃,維度詳細至每周每天…

Estimands與Intercurrent Events:臨床試驗與統計學核心框架

1. Estimands(估計目標)概述 1.1 定義與作用 1.1.1 定義 Estimand是臨床試驗中需明確提出的科學問題,即研究者希望通過數據估計的“目標量”,定義“治療效應”具體含義,確保分析結果與臨床問題一致。 例如,在研究某種新藥對高血壓患者降壓效果時,Estimand可定義為“在…

Jsp技術入門指南【十】IDEA 開發環境下實現 MySQL 數據在 JSP 頁面的可視化展示,實現前后端交互

Jsp技術入門指南【十】IDEA 開發環境下實現 MySQL 數據在 JSP 頁面的可視化展示,實現前后端交互 前言一、JDBC 核心接口和類:數據庫連接的“工具箱”1. 常用的 2 個“關鍵類”2. 必須掌握的 5 個“核心接口” 二、創建 JDBC 程序的步驟1. 第一步&#xf…

深入理解HotSpot JVM 基本原理

關于JAVA Java編程語言是一種通用的、并發的、面向對象的語言。它的語法類似于C和C++,但它省略了許多使C和C++復雜、混亂和不安全的特性。 Java 是幾乎所有類型的網絡應用程序的基礎,也是開發和提供嵌入式和移動應用程序、游戲、基于 Web 的內容和企業軟件的全球標準。. 從…

【HTTP/3:互聯網通信的量子飛躍】

HTTP/3:互聯網通信的量子飛躍 如果說HTTP/1.1是鄉村公路,HTTP/2是現代高速公路系統,那么HTTP/3就像是一種革命性的"傳送門"技術,它徹底重寫了數據傳輸的底層規則,讓信息幾乎可以瞬間抵達目的地,…

Apipost免費版、企業版和私有化部署詳解

Apipost是企業級的 API 研發協作一體化平臺,為企業提供 API研發測試管理全鏈路解決方案,不止于API研發場景,增強企業API資產管理。 Apipost 基于同一份數據源,同時提供給后端開發、前端開發、測試人員使用的接口調試、Mock、自動化…

使用若依二次開發商城系統-1:搭建若依運行環境

前言 若依框架有很多版本,這里使用的是springboot3vue3這樣的一個前后端分離的版本。 一.操作步驟 1 下載springboot3版本的后端代碼 后端springboot3的代碼路徑,https://gitee.com/y_project/RuoYi-Vue 需要注意我們要的是springboot3分支。 先用g…

速成GO訪問sql,個人筆記

更多個人筆記:(僅供參考,非盈利) gitee: https://gitee.com/harryhack/it_note github: https://github.com/ZHLOVEYY/IT_note 本文是基于原生的庫 database/sql進行初步學習 基于ORM等更多操作可以關注我…

【C++指南】告別C字符串陷阱:如何實現封裝string?

🌟 各位看官好,我是egoist2023! 🌍 種一棵樹最好是十年前,其次是現在! 💬 注意:本章節只詳講string中常用接口及實現,有其他需求查閱文檔介紹。 🚀 今天通過了…

系統架構師2025年論文《論軟件架構評估2》

論軟件系統架構評估 v2.0 摘要: 某市醫院預約掛號系統建設推廣應用項目是我市衛生健康委員會 2019 年發起的一項醫療衛生行業便民惠民信息化項目,目的是實現轄區內患者在轄區各公立醫療機構就診時,可以通過多種線上渠道進行預約掛號,提升就醫體驗。我作為系統架構師參與此…

BEVDet4D: Exploit Temporal Cues in Multi-camera 3D Object Detection

背景 對于現有的BEVDet方法,它對于速度的預測誤差要高于基于點云的方法,對于像速度這種與時間有關的屬性,僅靠單幀數據很難預測好。因此本文提出了BEVDet4D,旨在獲取時間維度上的豐富信息。它是在BEVDet的基礎上進行拓展,保留了之前幀的BEV特征,并將其進行空間對齊后與當…

el-upload 上傳邏輯和ui解耦,上傳七牛

解耦的作用在于如果后面要我改成從阿里云oss上傳文件,我只需要實現上傳邏輯從七牛改成阿里云即可,其他不用動。實現方式有2部分組成,一部分是上傳邏輯,一部分是ui。 上傳邏輯 大概邏輯就是先去服務端拿上傳token和地址&#xff0…

酒水類目電商代運營公司-品融電商:全域策略驅動品牌長效增長

酒水類目電商代運營公司-品融電商:全域策略驅動品牌長效增長 在競爭日益激烈的酒水市場中,品牌如何快速突圍并實現長效增長?品融電商憑借「效品合一 全域增長」方法論與全鏈路運營能力,成為酒水類目代運營的領跑者。從品牌定位、視…

機器學習特征工程中的數值分箱技術:原理、方法與實例解析

標題:機器學習特征工程中的數值分箱技術:原理、方法與實例解析 摘要: 分箱技術作為機器學習特征工程中的關鍵環節,通過將數值數據劃分為離散區間,能夠有效提升模型對非線性關系的捕捉能力,同時增強模型對異…

【MySQL專欄】MySQL數據庫的復合查詢語句

文章目錄 1、首先練習MySQL基本語句的練習①查詢工資高于500或崗位為MANAGER的雇員,同時還要滿足他們的姓名首字母為大寫的J②按照部門號升序而雇員的工資降序排序③使用年薪進行降序排序④顯示工資最高的員工的名字和工作崗位⑤顯示工資高于平均工資的員工信息⑥顯…

Python爬蟲(5)靜態頁面抓取實戰:requests庫請求頭配置與反反爬策略詳解

目錄 一、背景與需求?二、靜態頁面抓取的核心流程?三、requests庫基礎與請求頭配置?3.1 安裝與基本請求3.2 請求頭核心參數解析?3.3 自定義請求頭實戰 四、實戰案例:抓取豆瓣讀書Top250?1. 目標?2. 代碼實現3. 技術要點? 五、高階技巧與反反爬策略?5.1 動態…

HTML給圖片居中

在不同的布局場景下&#xff0c;讓 <img> 元素居中的方法有所不同。下面為你介紹幾種常見的居中方式 1. 塊級元素下的水平居中 如果 <img> 元素是塊級元素&#xff08;可以通過 display: block 設置&#xff09;&#xff0c;可以使用 margin: 0 auto 來實現水平居…

【高頻考點精講】前端構建工具對比:Webpack、Vite、Rollup和Parcel

前端構建工具大亂斗:Webpack、Vite、Rollup和Parcel誰是你的菜? 【初級】前端開發工程師面試100題(一) 【初級】前端開發工程師面試100題(二) 【初級】前端開發工程師的面試100題(速記版) 最近在后臺收到不少同學提問:“老李啊,現在前端構建工具這么多,我該選哪個?…

趕緊收藏!教您如何用 GitHub 賬號,獲取永久免費的 Docker 容器!!快速搭建我們的網站/應用!

文章目錄 ?? 介紹 ???? 演示環境 ???? 永久免費的 Docker 容器 ???? 注冊與登錄? 創建 Docker 容器?? 部署你的網站?? 注意事項?? 使用場景?? 相關鏈接 ???? 介紹 ?? 還在為搭建個人網站尋找免費方案而煩惱? 今天發現一個寶藏平臺!只需一個 Git…