基于 Python 的自然語言處理系列(83):InstructGPT 原理與實現

📌 論文地址:Training language models to follow instructions with human feedback
💻 參考項目:instructGOOSE

📷 模型架構圖:

一、引言:為什么需要 InstructGPT?

????????傳統的語言模型往往依賴于“最大似然訓練”,學會如何生成符合語法的文本,但卻不一定符合人類的指令意圖。OpenAI 提出的 InstructGPT 是一種結合 人類反饋監督 + 強化學習(RLHF) 的新訓練范式,其目標是使語言模型更能“聽人話”。

????????????????InstructGPT 的三階段訓練流程如下:

  1. SFT(Supervised Fine-tuning):使用人工標注的指令-回復數據進行有監督微調。

  2. RM(Reward Modeling):讓人工對模型生成的多個候選回復打分,從而訓練一個獎勵模型。

  3. PPO(Proximal Policy Optimization):使用 RL 算法訓練語言模型,使其生成的回復最大化獎勵模型的得分。

????????本篇將結合 instructGOOSE 項目,對上述三階段進行端到端復現,使用的數據集為 IMDb 影評文本,語言模型為 GPT-2

二、準備工作:環境與設備

# 安裝依賴
# pip3 install instruct_gooseimport os
import torchfrom datasets import load_dataset
from torch.utils.data import DataLoader, random_split
from tqdm.auto import tqdm# 設置 GPU 設備(如使用 Colab 建議 comment 掉代理配置)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

三、加載 IMDb 數據集并構建 DataLoader

dataset = load_dataset("imdb", split="train")# 為演示快速收斂,僅使用前 10 條數據
dataset, _ = random_split(dataset, lengths=[10, len(dataset) - 10])train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

四、加載 GPT-2 模型與 InstructGPT 工具鏈

from transformers import AutoTokenizer, AutoModelForCausalLM
from instruct_goose import Agent, RewardModel, RLHFTrainer, RLHFConfig, create_reference_modelmodel_name_or_path = "gpt2"# 加載主模型與獎勵模型
model_base = AutoModelForCausalLM.from_pretrained(model_name_or_path)
reward_model = RewardModel(model_name_or_path)# 加載 tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
eos_token_id = tokenizer.eos_token_id

五、創建 RL 模型代理與參考模型

# 構造 Agent(語言模型 + Value 網絡 + 采樣接口)
model = Agent(model_base)
ref_model = create_reference_model(model)

六、訓練配置與 RLHFTrainer 初始化

max_new_tokens = 20
generation_kwargs = {"min_length": -1,"top_k": 0.0,"top_p": 1.0,"do_sample": True,"pad_token_id": eos_token_id,"max_new_tokens": max_new_tokens
}config = RLHFConfig()  # 可使用默認參數trainer = RLHFTrainer(model, ref_model, config)

七、基于 PPO 的 InstructGPT 強化訓練

from torch import optimoptimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 3for epoch in range(num_epochs):for step, batch in enumerate(tqdm(train_dataloader)):# Step 1: 編碼輸入inputs = tokenizer(batch["text"],padding=True,truncation=True,return_tensors="pt")inputs = {k: v.to(device) for k, v in inputs.items()}# Step 2: 使用主模型生成回復response_ids = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"],**generation_kwargs)response_ids = response_ids[:, -max_new_tokens:]response_attention_mask = torch.ones_like(response_ids)# Step 3: 拼接 query + response,使用 Reward Model 評估得分with torch.no_grad():input_pairs = torch.stack([torch.cat([q, r], dim=0)for q, r in zip(inputs["input_ids"], response_ids)]).to(device)rewards = reward_model(input_pairs)# Step 4: 計算 PPO 損失并反向傳播loss = trainer.compute_loss(query_ids=inputs["input_ids"],query_attention_mask=inputs["attention_mask"],response_ids=response_ids,response_attention_mask=response_attention_mask,rewards=rewards)optimizer.zero_grad()loss.backward()optimizer.step()print(f"[Epoch {epoch+1}] Loss = {loss.item():.4f}")

八、推理測試與結果展示

# 輸入一句文本進行測試
input_text = dataset[0]['text']
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)output = model_base.generate(input_ids, max_length=256,num_beams=5, no_repeat_ngram_size=2,top_k=50, top_p=0.95, temperature=0.7
)generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("🧠 模型生成結果:\n", generated_text)

九、總結

????????本篇我們復現了 InstructGPT 的核心訓練框架,依賴于三大模塊:

  • 語言模型(GPT2);

  • 獎勵模型(RewardModel);

  • 強化訓練器(RLHFTrainer + PPO loss)。

????????通過引入人類反饋偏好作為優化目標,InstructGPT 展現出更強的任務理解與指令遵循能力,已經成為 ChatGPT 訓練體系的核心組成部分之一。

🔮 下一篇預告

????????📘《基于 Python 的自然語言處理系列(84):SFT原理與實踐》

如果你覺得這篇博文對你有幫助,請點贊、收藏、關注我,并且可以打賞支持我!

歡迎關注我的后續博文,我將分享更多關于人工智能、自然語言處理和計算機視覺的精彩內容。

謝謝大家的支持!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/902540.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/902540.shtml
英文地址,請注明出處:http://en.pswp.cn/news/902540.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

零基礎入門 Verilog VHDL:在線仿真與 FPGA 實戰全流程指南

摘要 本文面向零基礎讀者,全面詳解 Verilog 與 VHDL 兩大主流硬件描述語言(HDL)的核心概念、典型用法及開發流程。文章在淺顯易懂的語言下,配合多組可在線驗證的示例代碼、PlantUML 電路結構圖,讓你在 EDA Playground 上動手體驗數字電路設計與仿真,并深入了解從 HDL 編寫…

Kubernetes控制平面組件:API Server詳解(二)

云原生學習路線導航頁(持續更新中) kubernetes學習系列快捷鏈接 Kubernetes架構原則和對象設計(一)Kubernetes架構原則和對象設計(二)Kubernetes架構原則和對象設計(三)Kubernetes控…

云服務器存儲空間不足導致的docker image運行失敗或Not enough space in /var/cache/apt/archives

最近遇到了兩次空間不足導致docker實例下的mongodb運行失敗的問題。 排查錯誤 首先用nettools看下mongodb端口有沒有被占用: sudo apt install net-tools netstat --all --program | grep 27017 原因和解決方案 系統日志文件太大 一般情況下日志文件不會很大…

爬蟲學習——下載文件和圖片、模擬登錄方式進行信息獲取

一、下載文件和圖片 Scrapy中有兩個類用于專門下載文件和圖片,FilesPipeline和ImagesPipeline,其本質就是一個專門的下載器,其使用的方式就是將文件或圖片的url傳給它(eg:item[“file_urls”])。使用之前需要在settings.py文件中對其進行聲明…

拒絕用電“盲人摸象”,體驗智能微斷的無縫升級

🌟 為什么需要智能微型斷路器? 傳統斷路器只能被動保護電路,而安科瑞智能微型斷路器不僅能實時監測用電數據,還能遠程控制、主動預警,堪稱用電安全的“全能衛士”!無論是家庭、工廠還是商業樓宇&#xff0…

如何優雅地為 Axios 配置失敗重試與最大嘗試次數

在 Vue 3 中,除了使用自定義的 useRequest 鉤子函數外,還可以通過 axios 的攔截器 或 axios-retry 插件實現接口請求失敗后的重試邏輯。以下是兩種具體方案的實現方式: 方案一:使用 axios 攔截器實現重試 實現步驟: 通…

【Leetcode刷題隨筆】242.有效的字母異位詞

1. 題目描述 給定兩個僅包含小寫字母的字符串 s 和 t ,編寫一個函數來判斷 t 是否是 s 的 字母異位詞。 字母異位詞定義:兩個字符串包含的字母種類和數量完全相同,但順序可以不同(例如 “listen” 和 “silent”)。 …

示例:spring xml+注解混合配置

以下是一個 Spring XML 注解的混合配置示例,結合了 XML 的基礎設施配置(如數據源、事務管理器)和注解的便捷性(如依賴注入、事務聲明)。所有業務層代碼通過注解簡化,但核心配置仍通過 XML 管理。 1. 項目結…

Crawl4AI:打破數據孤島,開啟大語言模型的實時智能新時代

當大語言模型遇見數據饑渴癥 在人工智能的競技場上,大語言模型(LLMs)正以驚人的速度進化,但其認知能力的躍升始終面臨一個根本性挑戰——如何持續獲取新鮮、結構化、高相關性的數據。傳統數據供給方式如同輸血式營養支持&#xff…

【機器學習-周總結】-第4周

以下是本周學習內容的整理總結,從技術學習、實戰應用到科研輔助技能三個方面歸納: 文章目錄 📘 一、技術學習模塊:TCN 基礎知識與結構理解🔹 博客1:【時序預測05】– TCN(Temporal Convolutiona…

Mysql--基礎知識點--79.1--雙主架構如何避免回環復制

1 避免回環過程 在MySQL雙主架構中,GTID(全局事務標識符)通過以下流程避免數據回環: 1 事務提交與GTID生成 在Master1節點,事務提交時生成一個全局唯一的GTID(如3E11FA47-71CA-11E1-9E33-C80AA9429562:2…

安寶特科技 | AR眼鏡在安保與安防領域的創新應用及前景

隨著科技的不斷進步,增強現實(AR)技術逐漸在多個領域展現出其獨特的優勢,尤其是在安保和安防方面。AR眼鏡憑借其先進的功能,在機場、車站、海關、港口、工廠、園區、消防局和警察局等行業中為安保人員提供了更為高效、…

Linux第十講:進程間通信IPC

Linux第十講:進程間通信IPC 1.進程間通信介紹1.1什么是進程間通信1.2為什么要進程間通信1.3怎么進行進程間通信 2.管道2.1理解管道2.2匿名管道的實現代碼2.3管道的五種特性2.3.1匿名管道,只能用來進行具有血緣關系的進程進行通信(通常是父子)2.3.2管道文…

微信小程序通過mqtt控制esp32

目錄 1.注冊巴法云 2.設備連接mqtt 3.微信小程序 備注 本文esp32用的是MicroPython固件,MQTT服務用的是巴法云。 本文參考巴法云官方教程:https://bemfa.blog.csdn.net/article/details/115282152 1.注冊巴法云 注冊登陸并新建一個topic&#xff…

SQLMesh隔離系統深度實踐指南:動態模式映射與跨環境計算復用

在數據安全與開發效率的雙重壓力下,SQLMesh通過動態模式映射、跨環境計算復用和元數據隔離機制三大核心技術,完美解決了生產與非生產環境的數據壁壘問題。本文提供從環境配置到生產部署的完整實施框架,助您構建安全、高效、可擴展的數據工程體…

Spring Data詳解:簡化數據訪問層的開發實踐

1. 什么是Spring Data? Spring Data 是Spring生態中用于簡化數據訪問層(DAO)開發的核心模塊,其目標是提供統一的編程模型,支持關系型數據庫(如MySQL)、NoSQL(如MongoDB)…

15 nginx 中默認的 proxy_buffering 導致基于 http 的流式響應存在 buffer, 以 4kb 一批次返回

前言 這也是最近碰到的一個問題 直連 流式 http 服務, 發現 流式響應正常, 0.1 秒接收到一個響應 但是 經過 nginx 代理一層之后, 就發現了 類似于緩沖的效果, 1秒接收到 10個響應 最終 調試 發現是 nginx 的 proxy_buffering 配置引起的 然后 更新 proxy_buffering 為…

源超長視頻生成模型:FramePack

FramePack 是一種下一幀(下一幀部分)預測神經網絡結構,可以逐步生成視頻。 FramePack 將輸入上下文壓縮為固定長度,使得生成工作量與視頻長度無關。即使在筆記本電腦的 GPU 上,FramePack 也能處理大量幀,甚…

第6次課 貪心算法 A

向日葵朝著太陽轉動,時刻追求自身成長的最大可能。 貪心策略在一輪輪的簡單選擇中,逐步導向最佳答案。 課堂學習 引入 貪心算法(英語:greedy algorithm),是用計算機來模擬一個「貪心」的人做出決策的過程…

Windows使用SonarQube時啟動腳本自動關閉

一、解決的問題 Windows使用SonarQube時啟動腳本自動關閉,并發生報錯: ERROR: Elasticsearch did not exit normally - check the logs at E:\Inori_Code\Year3\SE\sonarqube-25.2.0.102705\sonarqube-25.2.0.102705\logs\sonarqube.log ERROR: Elastic…