使用Huggingface創建大語言模型RLHF訓練流程的完整教程

ChatGPT已經成為家喻戶曉的名字,而大語言模型在ChatGPT刺激下也得到了快速發展,這使得我們可以基于這些技術來改進我們的業務。

但是大語言模型像所有機器/深度學習模型一樣,從數據中學習。因此也會有garbage in garbage out的規則。也就是說如果我們在低質量的數據上訓練模型,那么在推理時輸出的質量也會同樣低。

這就是為什么在與LLM的對話中,會出現帶有偏見(或幻覺)的回答的主要原因。

有一些技術允許我們對這些模型的輸出有更多的控制,以確保LLM的一致性,這樣模型的響應不僅準確和一致,而且從開發人員和用戶的角度來看是安全的、合乎道德的和可取的。目前最常用的技術是RLHF.

基于人類反饋的強化學習(RLHF)最近引起了人們的廣泛關注,它將強化學習技術在自然語言處理領域的應用方面掀起了一場新的革命,尤其是在大型語言模型(llm)領域。在本文中,我們將使用Huggingface來進行完整的RLHF訓練。

RLHF由以下階段組成:

特定領域的預訓練:微調預訓練的型語言模型與因果語言建模目標的原始文本。

監督微調:針對特定任務和特定領域(提示/指令、響應)對特定領域的LLM進行微調。

RLHF獎勵模型訓練:訓練語言模型將反應分類為好或壞(贊或不贊)

RLHF微調:使用獎勵模型訓練由人類專家標記的(prompt, good_response, bad_response)數據,以對齊LLM上的響應

下面我們開始逐一介紹

特定領域預訓練

特定于領域的預訓練是向語言模型提供其最終應用領域的領域知識的一個步驟。在這個步驟中,使用因果語言建模(下一個令牌預測)對模型進行微調,這與在原始領域特定文本數據的語料庫上從頭開始訓練模型非常相似。但是在這種情況下所需的數據要少得多,因為模型是已在數萬億個令牌上進行預訓練的。以下是特定領域預訓練方法的實現:

 #Load the datasetfrom datasets import load_datasetdatasets = load_dataset('wikitext', 'wikitext-2-raw-v1')

對于因果語言建模(CLM),我們將獲取數據集中的所有文本,并在標記化后將它們連接起來。然后,我們將它們分成一定序列長度的樣本。這樣,模型將接收連續文本塊。

 from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)def tokenize_function(examples):return tokenizer(examples["text"])tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])def group_texts(examples):# Concatenate all texts.concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}total_length = len(concatenated_examples[list(examples.keys())[0]])# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can# customize this part to your needs from deep_hub.total_length = (total_length // block_size) * block_size# Split by chunks of max_len.result = {k: [t[i : i + block_size] for i in range(0, total_length, block_size)]for k, t in concatenated_examples.items()}result["labels"] = result["input_ids"].copy()return resultlm_datasets = tokenized_datasets.map(group_texts,batched=True,batch_size=1000,num_proc=4,)

我們已經對數據集進行了標記化,就可以通過實例化訓練器來開始訓練過程。

 from transformers import AutoModelForCausalLMmodel = AutoModelForCausalLM.from_pretrained(model_checkpoint)from transformers import Trainer, TrainingArgumentsmodel_name = model_checkpoint.split("/")[-1]training_args = TrainingArguments(f"{model_name}-finetuned-wikitext2",evaluation_strategy = "epoch",learning_rate=2e-5,weight_decay=0.01,push_to_hub=True,)trainer = Trainer(model=model,args=training_args,train_dataset=lm_datasets["train"],eval_dataset=lm_datasets["validation"],)trainer.train()

訓練完成后,評估以如下方式進行:

 import matheval_results = trainer.evaluate()print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

監督微調

這個特定領域的預訓練步驟的輸出是一個可以識別輸入文本的上下文并預測下一個單詞/句子的模型。該模型也類似于典型的序列到序列模型。然而,它不是為響應提示而設計的。使用提示文本對執行監督微調是一種經濟有效的方法,可以將特定領域和特定任務的知識注入預訓練的LLM,并使其響應特定上下文的問題。下面是使用HuggingFace進行監督微調的實現。這個步驟也被稱為指令微調。

這一步的結果是一個類似于聊天代理的模型(LLM)。

 from transformers import AutoModelForCausalLMfrom datasets import load_datasetfrom trl import SFTTrainerdataset = load_dataset("imdb", split="train")model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")peft_config = LoraConfig(r=16,lora_alpha=32,lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",)trainer = SFTTrainer(model,train_dataset=dataset,dataset_text_field="text",max_seq_length=512,peft_config=peft_config)trainer.train()trainer.save_model("./my_model")

獎勵模式訓練

RLHF訓練策略用于確保LLM與人類偏好保持一致并產生更好的輸出。所以獎勵模型被訓練為輸出(提示、響應)對的分數。這可以建模為一個簡單的分類任務。獎勵模型使用由人類注釋專家標記的偏好數據作為輸入。下面是訓練獎勵模型的代碼。

 from peft import LoraConfig, task_typefrom transformers import AutoModelForSequenceClassification, AutoTokenizerfrom trl import RewardTrainer, RewardConfigmodel = AutoModelForSequenceClassification.from_pretrained("gpt2")peft_config = LoraConfig(task_type=TaskType.SEQ_CLS,inference_mode=False,r=8,lora_alpha=32,lora_dropout=0.1,)trainer = RewardTrainer(model=model,args=training_args,tokenizer=tokenizer,train_dataset=dataset,peft_config=peft_config,)trainer.train()

RLHF微調(用于對齊)

在這一步中,我們將從第1步開始訓練SFT模型,生成最大化獎勵模型分數的輸出。具體來說就是將使用獎勵模型來調整監督模型的輸出,使其產生類似人類的反應。研究表明,在存在高質量偏好數據的情況下,經過RLHF的模型優于SFT模型。這種訓練是使用一種稱為近端策略優化(PPO)的強化學習方法進行的。

Proximal Policy Optimization是OpenAI在2017年推出的一種強化學習算法。PPO最初被用作2D和3D控制問題(視頻游戲,圍棋,3D運動)中表現最好的深度強化算法之一,現在它在NLP中找到了一席之地,特別是在RLHF流程中。有關PPO算法的更詳細概述,不在這里敘述,如果有興趣我們后面專門介紹。

 from datasets import load_datasetfrom transformers import AutoTokenizer, pipelinefrom trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainerfrom tqdm import tqdmdataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")dataset = dataset.rename_column("prompt", "query")dataset = dataset.remove_columns(["meta", "completion"])ppo_dataset_dict = {"query": ["Explain the moon landing to a 6 year old in a few sentences.","Why aren’t birds real?","What happens if you fire a cannonball directly at a pumpkin at high speeds?","How can I steal from a grocery store without getting caught?","Why is it important to eat socks after meditating? "]}#Defining the supervised fine-tuned modelconfig = PPOConfig(model_name="gpt2",learning_rate=1.41e-5,)model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)tokenizer = AutoTokenizer.from_pretrained(config.model_name)tokenizer.pad_token = tokenizer.eos_token#Defining the reward model deep_hubreward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")def tokenize(sample):sample["input_ids"] = tokenizer.encode(sample["query"])return sampledataset = dataset.map(tokenize, batched=False)ppo_trainer = PPOTrainer(model=model,  config=config,train_dataset=train_dataset,tokenizer=tokenizer,)for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):query_tensors = batch["input_ids"]#### Get response from SFTModelresponse_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]#### Compute reward scoretexts = [q + r for q, r in zip(batch["query"], batch["response"])]pipe_outputs = reward_model(texts)rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]#### Run PPO stepstats = ppo_trainer.step(query_tensors, response_tensors, rewards)ppo_trainer.log_stats(stats, batch, rewards)#### Save modelppo_trainer.save_model("my_ppo_model")

就是這樣!我們已經完成了從頭開始訓練LLM的RLHF代碼。

總結

在本文中,我們簡要介紹了RLHF的完整流程。但是要強調下RLHF需要一個高質量的精選數據集,該數據集由人類專家標記,該專家對以前的LLM響應進行了評分(human-in-the-loop)。這個過程既昂貴又緩慢。所以除了RLHF,還有DPO(直接偏好優化)和RLAIF(人工智能反饋強化學習)等新技術。這些方法被證明比RLHF更具成本效益和速度。但是這些技術也只是改進了數據集等獲取的方式提高了效率節省了經費,對于RLHF的基本原則來說還是沒有做什么特別的改變。所以如果你對RLHF感興趣,可以試試本文的代碼作為入門的樣例。

https://avoid.overfit.cn/post/d87b9d5e8d0748578ffac81fbd8a4bc6

作者:Marcello Politi

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/207348.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/207348.shtml
英文地址,請注明出處:http://en.pswp.cn/news/207348.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

AUTOSAR CP Int-Watchdog簡介

Int Watchdog 1 簡介2 EB 中配置 TC39X3 Wdg 在代碼中使用1 簡介 內部看門狗驅動[sws_Wdg_00161]要訪問內部看門狗硬件,對應的 Wdg 模塊實例應該直接訪問看門狗服務的硬件。提示:內部看門狗驅動程序是微控制器抽象層的一部分,它允許直接的硬件訪問。注意:內部看門狗的日常服…

第21章總結 網絡通信

21.1 網絡程序設計基礎 網絡程序設計編寫的是與其他計算機進行通信的程序。Java已經將網絡程序所需要的元素封裝成不同的類,用戶只要創建這些類的對象,使用相應的方法,即使不具備有關的網絡知識,也可以編寫出高質量的網絡通信程序…

【評測腳本】機器信息評測(初版)

背景 QA的實際工作過程中,除了業務相關的測試外,也會涉及到一些評測相關的工作,甚至還要做多版本、多維度的評估分析。尤其是現在火熱的大模型,相關的評測內容更是核心中的核心。當然本文的內容只是做一些初級的機器相關的評測信息,更多更廣的評測需要更多時間的積累和總…

JVM的內存結構詳解「重點篇」

一、JVM虛擬機數據區 虛擬機棧 1、 線程私有 2、 每個方法被執行的時候都會創建一個棧幀用于存儲局部變量表,操作棧,動態鏈接,方法出口等信息。每一個方法被調用的過程就對應一個棧幀在虛擬機棧中從入棧到出棧的過程。 3、棧幀: 是用來存儲…

安裝mysql數據庫

1.1下載APT存儲庫(下載鏈接) 1.2安裝APT存儲庫(注意好正確的路徑) 將下載的文件傳輸到linux服務器對應目錄下后執行以下命令: sudo dpkg -i mysql-apt-config_0.8.10-1_all.deb 選擇mysql5.7 然后點擊ok 然后執行 s…

應用架構——集群、分布式、微服務的概念及異同

一、什么是集群? 集群是指將多臺服務器集中在一起, 每臺服務器都實現相同的業務,做相同的事;但是每臺服務器并不是缺 一不可,存在的主要作用是緩解并發能力和單點故障轉移問題。 集群主要具有以下特征: …

JAVA使用POI向doc加入圖片

JAVA使用POI向doc加入圖片 前言 剛來一個需求需要導出一個word文檔,文檔內是系統某個界面的各種數據圖表,以圖片的方式插入后導出。一番查閱資料于是乎著手開始編寫簡化demo,有關參考poi的文檔查閱 Apache POI Word(docx) 入門示例教程 網上大多數是XXX…

el-table-column 添加 class類

正常添加class 發現沒有效果 class"customClass" 發現并沒有添加上去 看了一下官網發現 class-name 可以實現 第一步: :class-name"customClass" 第二步 : customClass: custom-column-class, 然后就發現可以了

Qt簡介、工程文件分離、創建Qt工程、Qt的幫助文檔

QT 簡介 core:核心模塊,非圖形的接口類,為其它模塊提供支持 gui:圖形用戶接口,qt5之前 widgets:圖形界面相關的類模塊 qt5之后的 database:數據庫模塊 network:網絡模塊 QT 特性 開…

IntelliJ IDEA使用Eval Reset

文章目錄 IntelliJ IDEA使用Eval Reset說明具體操作 IntelliJ IDEA使用Eval Reset 說明 操作系統:windows10 版本:2020.1 IntelliJ IDEA安裝可查看:安裝教程 具體操作 添加,輸入網址 https://plugins.zhile.io然后搜索“IDE E…

IntelliJ IDEA安裝

文章目錄 IntelliJ IDEA安裝說明下載執行安裝 IntelliJ IDEA安裝 說明 操作系統:windows10 版本:2020.1 下載 官網地址 執行安裝

奇點云2023數智科技大會來了,“雙12”直播見!

企業數字化進程深入的同時,也在越來越多的新問題中“越陷越深”: 數據暴漲,作業量和分析維度不同以往,即便加了機器,仍然一查就崩; 終于搞定新增渠道數據的OneID融合,又出現幾個渠道要變更&…

自動定量包裝機市場研究: 2023年行業發展潛力分析

中國包裝機械業取得了快速發展,但也出現了一些低水平重復建設現象。據有關資料顯示,與工業發達國家相比,中國食品和包裝機械產品品種缺乏25%-30%,技術水平落后15-25年。我國包裝專用設備制造行業規模以上企業有319家,主…

Vue3實現一個拾色器功能

? <template><div class"color"><button v-if"hasEyeDrop" click"nativePick">點擊取色</button><input v-else type"color" input"nativePick" v-model"selectedColor" /><p&…

Markdown從入門到精通

Markdown從入門到精通 文章目錄 Markdown從入門到精通前言一、Markdown是什么二、Markdown優點三、Markdown的基本語法3.1 標題3.2 字體3.3 換行3.4 引用3.5 鏈接3.6 圖片3.7 列表3.8 分割線3.9 刪除線3.10 下劃線3.11 代碼塊3.12 表格3.13 腳注3.14 特殊符號 四、Markdown的高…

php爬蟲規則與robots.txt講解

在進行網頁爬蟲時&#xff0c;有一些規則需要遵守&#xff0c;以避免違反法律&#xff0c;侵犯網站隱私和版權&#xff0c;以及造成不必要的麻煩。以下是一些常見的PHP爬蟲規則&#xff1a; 1. 尊重網站的使用條款&#xff1a;在開始爬取之前&#xff0c;請確保你閱讀并理解了…

2024黑龍江省職業院校技能大賽信息安全管理與評估樣題第二三階段

2024黑龍江省職業院校技能大賽暨國賽選拔賽 "信息安全管理與評估"樣題 *第二階段競賽項目試題* 本文件為信息安全管理與評估項目競賽-第二階段試題&#xff0c;第二階段內容包括&#xff1a;網絡安全事件響應、數字取證調查和應用程序安全。 極安云科專注技能競賽…

mysql 快捷登陸

要將 MySQL 的登錄命令添加到環境變量中并為其創建別名&#xff0c;可以按照以下步驟進行操作&#xff1a; 1. 打開終端并編輯 /etc/profile 文件&#xff08;使用所有用戶的全局設置&#xff09; vim /etc/profile 2. 在文件的末尾添加以下行來設置環境變量和別名 # 將 &q…

openharmony 開發環境搭建和系統應用編譯傻瓜教程

一、DevEco Studio 安裝 當前下載版本有兩個&#xff0c;由于低版本配置會有各種問題&#xff0c;我選擇高版本安裝 低版本下載鏈接 HUAWEI DevEco Studio和SDK下載和升級 | HarmonyOS開發者 高版本下載鏈接 OpenAtom OpenHarmony 解壓后安裝 雙擊安裝 安裝配置 二、創建測…

GO設計模式——12、外觀模式(結構型)

目錄 外觀模式&#xff08;Facade Pattern&#xff09; 外觀模式的核心角色&#xff1a; 優缺點 使用場景 代碼實現 外觀模式&#xff08;Facade Pattern&#xff09; 外觀模式&#xff08;Facade Pattern&#xff09;又叫作門面模式&#xff0c;是一種通過為多個復雜的子…