基于 Python 的自然語言處理系列(82):Transformer Reinforcement Learning

🔗 本文所用工具:trltransformerspeftbitsandbytes
📘 官方文檔參考:https://huggingface.co/docs/trl

一、引言:從有監督微調到 RLHF 全流程

????????隨著語言大模型的發展,如何在大規模預訓練模型基礎上更精細地對齊人類偏好,成為了研究與應用的熱點。本文將介紹一套完整的 RLHF(Reinforcement Learning with Human Feedback)訓練流程,基于 Hugging Face 推出的 trl 庫,從 SFT(Supervised Fine-tuning)、RM(Reward Modeling)、到 PPO(Proximal Policy Optimization)三大階段,逐步實現對 Transformer 模型的強化學習優化。

????????本篇聚焦于 SFT 階段的實現,并以 Hugging Face 提供的 instruction-dataset 為例,介紹如何使用 trl 和 PEFT(參數高效微調)技術訓練一個高效對齊指令的語言模型。

二、安裝與環境準備

????????確保安裝以下庫(建議使用 PyTorch + CUDA 環境):

pip install trl transformers datasets peft bitsandbytes accelerate

三、加載并準備數據集

????????本例使用 HuggingFaceH4 團隊整理的 instruction-dataset

from datasets import load_datasetdataset = load_dataset("HuggingFaceH4/instruction-dataset")
dataset = dataset.remove_columns("meta")  # 移除無用字段
dataset

四、構建模型及量化配置(4-bit)

????????使用 BitsAndBytesConfig 對模型進行 4-bit 量化,可大幅降低顯存占用:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import prepare_model_for_kbit_trainingmodel_name = "lmsys/fastchat-t5-3b-v1.0"bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,
)base_model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.bfloat16,quantization_config=bnb_config
)base_model.config.use_cache = False
base_model = prepare_model_for_kbit_training(base_model)

五、注入 LoRA 參數高效微調機制

????????首先識別所有 4-bit 線性模塊并定義 LoRA 參數配置:

import bitsandbytes as bnb
from peft import get_peft_model, LoraConfigdef find_all_linear_names(model):cls = bnb.nn.Linear4bitlora_module_names = set()for name, module in model.named_modules():if isinstance(module, cls):names = name.split(".")lora_module_names.add(names[0] if len(names) == 1 else names[-1])return list(lora_module_names)peft_config = LoraConfig(r=128,lora_alpha=16,target_modules=find_all_linear_names(base_model),lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",
)base_model = get_peft_model(base_model, peft_config)

????????打印可訓練參數占比:

def print_trainable_parameters(model):trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)total = sum(p.numel() for p in model.parameters())print(f"Trainable params: {trainable} / {total} ({trainable / total:.2%})")print_trainable_parameters(base_model)

六、定義 Prompt 格式化函數

????????將數據集中的 promptcompletion 格式化為統一格式:

def formatting_prompts_func(example):return [f"### Input: ```{prompt}```\n ### Output: {completion}"for prompt, completion in zip(example["prompt"], example["completion"])]

七、訓練參數設置與 SFTTrainer 訓練器

????????使用 SFTTrainer 執行指令微調訓練,支持 gradient checkpointing、cosine 學習率調度等高級策略:

from transformers import TrainingArguments
from trl import SFTTraineroutput_dir = "./results"training_args = TrainingArguments(output_dir=output_dir,per_device_train_batch_size=4,gradient_accumulation_steps=4,gradient_checkpointing=True,max_grad_norm=0.3,num_train_epochs=15,learning_rate=2e-4,bf16=True,save_total_limit=3,logging_steps=10,optim="paged_adamw_32bit",lr_scheduler_type="cosine",warmup_ratio=0.05,
)tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"trainer = SFTTrainer(model=base_model,train_dataset=dataset,tokenizer=tokenizer,max_seq_length=2048,formatting_func=formatting_prompts_func,args=training_args
)

????????執行訓練:

trainer.train()
trainer.save_model(output_dir)

????????保存最終模型權重與 tokenizer:

import os
final_output_dir = os.path.join(output_dir, "final_checkpoint")
trainer.model.save_pretrained(final_output_dir)
tokenizer.save_pretrained(final_output_dir)

八、小結與展望

????????通過本文,我們使用 trl 工具鏈完成了 RLHF 的第一階段:SFT 有監督微調。你可以根據項目實際需求,替換為自定義數據集或更大規模模型。后續步驟(RM 獎勵建模 + PPO 策略優化)將在下一篇繼續介紹。

📌 下一篇預告

????????📘《基于 Python 的自然語言處理系列(83):InstructGPT》

????????敬請期待!

如果你覺得這篇博文對你有幫助,請點贊、收藏、關注我,并且可以打賞支持我!

歡迎關注我的后續博文,我將分享更多關于人工智能、自然語言處理和計算機視覺的精彩內容。

謝謝大家的支持!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/80227.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/80227.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/80227.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

JAVA猜數小游戲

import java.util.Random; import java.util.Scanner;public class HelloWorld {public static void main(String[] args) {Random rnew Random();int luck_number r.nextInt(100)1;while (true){System.out.println("輸入猜數字");Scanner sc new Scanner(System…

GPU渲染階段介紹+Shader基礎結構實現

GPU是什么 (CPU)Center Processing Unit:邏輯編程 (GPU)Graphics Processing Unit:圖形處理(矩陣運算,數據公式運算,光柵化) 渲染管線 渲染管線也稱為渲染流水線&#x…

Spring Boot + MyBatis 動態字段更新方法

在Spring Boot和MyBatis中,實現動態更新不固定字段的步驟如下: 方法一:使用MyBatis動態SQL(適合字段允許為null的場景) 定義實體類 包含所有可能被更新的字段。 Mapper接口 定義更新方法,參數為實體對象&…

單例模式:確保唯一實例的設計模式

單例模式:確保唯一實例的設計模式 一、模式核心:保證類僅有一個實例并提供全局訪問點 在軟件開發中,有些類需要確保只有一個實例(如系統配置類、日志管理器),避免因多個實例導致狀態混亂或資源浪費。 單…

UnoCSS原子CSS引擎-前端福音

UnoCSS是一款原子化的即時按需 CSS 引擎,其中沒有核心實用程序,所有功能都是通過預設提供的。默認情況下UnoCSS應用通過預設來實現相關功能。 UnoCSS中文文檔: https://www.unocss.com.cn 前有很多種原子化的框架,例如 Tailwind…

【Qwen2.5-VL 踩坑記錄】本地 + 海外賬號和國內賬號的 API 調用區別(阿里云百煉平臺)

API 調用 阿里云百煉平臺的海內外 API 的區別: 海外版:需要進行 API 基礎 URL 設置國內版:無需設置。 本人的服務器在香港,采用海外版的 API 時,需要進行如下API端點配置 / API基礎URL設置 / API客戶端配置&#xf…

C語言筆記(鵬哥)上課板書+課件匯總(結構體)-----數據結構常用

結構體 目錄: 1、結構體類型聲明 2、結構體變量的創建和初始化 3、結構體成員訪問操作符 4、結構體內存對齊*****(重要指數五顆星) 5、結構體傳參 6、結構體實現位段 一、結構體類型聲明 其實在指針中我們已經講解了一些結構體內容了&…

UV: Python包和項目管理器(從入門到不放棄教程)

目錄 UV: Python包和項目管理器(從入門到不放棄教程)1. 為什么用uv,而不是conda或者pip2. 安裝uv(Windows)2.1 powershell下載2.2 winget下載2.3 直接下載安裝包 3. uv教程3.1 創建虛擬環境 (uv venv) 4. uvx5. 此pip非…

網絡開發基礎(游戲方向)之 概念名詞

前言 1、一款網絡游戲分為客戶端和服務端兩個部分,客戶端程序運行在用戶的電腦或手機上,服務端程序運行在游戲運營商的服務器上。 2、客戶端和服務端之間,服務端和服務端之間一般都是使用TCP網絡通信。客戶端和客戶端之間通過服務端的消息轉…

java將pdf轉換成word

1、jar包準備 在項目中新增lib目錄&#xff0c;并將如下兩個文件放入lib目錄下 aspose-words-15.8.0-jdk16.jar aspose-pdf-22.9.jar 2、pom.xml配置 <dependency><groupId>com.aspose</groupId><artifactId>aspose-pdf</artifactId><versi…

【C/C++】插件機制:基于工廠函數的動態插件加載

本文介紹了如何通過 C 的 工廠函數、動態庫&#xff08;.so 文件&#xff09;和 dlopen / dlsym 實現插件機制。這個機制允許程序在運行時動態加載和調用插件&#xff0c;而無需在編譯時知道插件的具體類型。 一、 動態插件機制 在現代 C 中&#xff0c;插件機制廣泛應用于需要…

【音視頻】AAC-ADTS分析

AAC-ADTS 格式分析 AAC?頻格式&#xff1a;Advanced Audio Coding(?級?頻解碼)&#xff0c;是?種由MPEG-4標準定義的有損?頻壓縮格式&#xff0c;由Fraunhofer發展&#xff0c;Dolby, Sony和AT&T是主 要的貢獻者。 ADIF&#xff1a;Audio Data Interchange Format ?…

機器學習 Day12 集成學習簡單介紹

1.集成學習概述 1.1. 什么是集成學習 集成學習是一種通過組合多個模型來提高預測性能的機器學習方法。它類似于&#xff1a; 超級個體 vs 弱者聯盟 單個復雜模型(如9次多項式函數)可能能力過強但容易過擬合 組合多個簡單模型(如一堆1次函數)可以增強能力而不易過擬合 集成…

通過爬蟲方式實現頭條號發布視頻(2025年4月)

1、將真實的cookie貼到代碼目錄中toutiaohao_cookie.txt文件里,修改python代碼里的user_agent和video_path, cover_path等變量的值,最后運行python腳本即可; 2、運行之前根據import提示安裝一些常見依賴,比如requests等; 3、2025年4月份最新版; 代碼如下: import js…

Linux ssh免密登陸設置

使用 ssh-copy-id 命令來設置 SSH 免密登錄&#xff0c;并確保所有相關文件和目錄權限正確設置&#xff0c;可以按照以下步驟進行&#xff1a; 步驟 1&#xff1a;在源服務器&#xff08;198.120.1.109&#xff09;生成 SSH 密鑰對 如果還沒有生成 SSH 密鑰對&#xff0c;首先…

《讓機器人讀懂你的心:情感分析技術融合奧秘》

機器人早已不再局限于執行簡單機械的任務&#xff0c;人們期望它們能像人類伙伴一樣&#xff0c;理解我們的喜怒哀樂&#xff0c;實現更自然、溫暖的互動。情感分析技術&#xff0c;正是賦予機器人這種“理解人類情緒”能力的關鍵鑰匙&#xff0c;它的融入將徹底革新機器人與人…

Linux筆記---進程間通信:匿名管道

1. 管道通信 1.1 管道的概念與分類 管道&#xff08;Pipe&#xff09; 是進程間通信&#xff08;IPC&#xff09;的一種基礎機制&#xff0c;主要用于在具有親緣關系的進程&#xff08;如父子進程、兄弟進程&#xff09;之間傳遞數據&#xff0c;其核心特性是通過內核緩沖區實…

Ollama API 應用指南

1. 基礎信息 默認地址: http://localhost:11434/api數據格式: application/json支持方法: POST&#xff08;主要&#xff09;、GET&#xff08;部分接口&#xff09; 2. 模型管理 API (1) 列出本地模型 端點: GET /api/tags功能: 獲取已下載的模型列表。示例:curl http://lo…

【OSCP-vulnhub】Raven-2

目錄 端口掃描 本地/etc/hosts文件解析 目錄掃描&#xff1a; 第一個flag 利用msf下載exp flag2 flag3 Mysql登錄 查看mysql的運行權限 MySql提權&#xff1a;UDF 查看數據庫寫入條件 查看插件目錄 查看是否可以遠程登錄 gcc編譯.o文件 創建so文件 創建臨時監聽…

Podman Desktop:現代輕量容器管理利器(Podman與Docker)

前言 什么是 Podman Desktop&#xff1f; Podman Desktop 是基于 Podman CLI 的圖形化開源容器管理工具&#xff0c;運行在 Windows&#xff08;或 macOS&#xff09;上&#xff0c;默認集成 Fedora Linux&#xff08;WSL 2 環境&#xff09;。它提供與 Docker 類似的使用體驗…