大模型增量預訓練新技巧-解決災難性遺忘

大模型增量預訓練新技巧-解決災難性遺忘

機器學習算法與自然語言處理?2024年03月21日 00:02?吉林

以下文章來源于NLP工作站?,作者劉聰NLP

NLP工作站.

AIGC前沿知識分享&落地經驗總結

轉載自 |?NLP工作站

作者 |?劉聰NLP

目前不少開源模型在通用領域具有不錯的效果,但由于缺乏領域數據,往往在一些垂直領域中表現不理想,這時就需要增量預訓練和微調等方法來提高模型的領域能力。

但在領域數據增量預訓練或微調時,很容易出現災難性遺忘現象,也就是學會了垂直領域知識,但忘記了通用領域知識,之前介紹過增量預訓練以及領域大模型訓練技巧,詳見:

  • 如何更好地繼續預訓練-Continue PreTraining

  • 領域大模型-訓練Trick&落地思考

今天給大家帶來一篇增量預訓練方法-Llama-Pro,對LLMs進行Transformer塊擴展后,增量預訓練過程中僅對新增塊進行訓練,有效地進行模型知識注入,并且極大程度地避免災難性遺忘。

圖片

LLaMA Pro: Progressive LLaMA with Block Expansion

 

LLaMA?Pro:?Progressive?LLaMA?with?Block?Expansion
Paper:?https://arxiv.org/abs/2401.02415
Github:?https://github.com/TencentARC/LLaMA-Pro

塊擴展方法

塊擴展,顧名思義,就是在原始模型中每個Transformer塊或者某幾個Transformer塊增加一個Transformer塊,但為了保持擴展后的模型輸出保持不變,需要增加的塊為恒等塊(輸入輸出相同),如下圖所示。

圖片

在構建恒等塊過程中,主要是將多頭注意力層和FFN層中的最后一個線性層(Linear權重置為0變成Zero-Linear,即可保持經過該塊的輸入輸出一致。

PS:論文附錄A中寫了大段的推導公式來證明,在此不做過多介紹。

塊的增加方式是,對原始模型的L個Transformer塊分成N組,每組中包含M=L/N個Transformer塊,對于每組后添加P個恒等塊。代碼實現具體如下:

model?=?AutoModelForCausalLM.from_pretrained(model_path,?torch_dtype=torch.float16)
ckpt?=?model.state_dict()#?original_layers是模型原始層數,layers是模型最后達到層數
split?=?int(original_layers?/?(layers?-?original_layers))layer_cnt?=?0output?=?{}
for?i?in?range(original_layers):for?k?in?ckpt:if?('layers.'?+?str(i)?+?'.')?in?k:output[k.replace(('layers.'?+?str(i)?+?'.'),?('layers.'?+?str(layer_cnt)?+?'.'))]?=?ckpt[k]layer_cnt?+=?1if?(i+1)?%?split?==?0:for?k?in?ckpt:if?('layers.'?+?str(i)?+?'.')?in?k:if?'down_proj'?in?k?or?'o_proj'?in?k:output[k.replace(('layers.'?+?str(i)?+?'.'),?('layers.'?+?str(layer_cnt)?+?'.'))]?=?torch.zeros_like(ckpt[k])else:output[k.replace(('layers.'?+?str(i)?+?'.'),?('layers.'?+?str(layer_cnt)?+?'.'))]?=?ckpt[k]layer_cnt?+=?1assert?layer_cnt==layers
for?k?in?ckpt:if?not?'layers'?in?k:output[k]?=?ckpt[k]torch.save(output,?output_path)

實驗細節

數據由代碼和數學組成,其中代碼數據采用The-Stack-Dedup數據集中Python語言部分共22B Token,數學數據采用Proof-Pile-2數據集中AlgebraicStack、OpenWebMath和ArXiv部分共55B,詳細如下表所示。

圖片

數據分布

基礎模型為LLaMA2-7B模型,通過塊擴展方法將32層模型擴展到40層,其中 P=1,M=4,N=8,每個組從4個Transformer塊擴展到5個Transformer塊。

對于代碼和數學數據進行增量預訓練,批量大小為1024,序列最大長度為4096,預熱比率為6%,學習率為2e-4,采用余弦學習率調度器,BF16混合精度訓練,權重衰減為0.1。使用16個NVIDIA H800 GPU進行了15900個步驟的訓練,大約耗費2830個GPU/小時

ARC、HellaSwag、MMLU、TruthfulQA、Winogrande、GSM8K、GSM8K-PoT、HumanEval、MBPP等多個評測數據集中進行評測,可以看出,在保持通用任務能力不下降的情況下,數學和代碼能力較原始LLaMA2-7B模型有很大提升。

圖片

圖片

討論分析

對比塊擴展方法與正常訓練和Lora方法之間的區別,采用TRACE基準利用總體性能(OP)和逆向轉移(BWT)指標進行評估。,如下表所示,塊擴展方法整體提升較大。

圖片

對比塊個數對塊擴展方法的影響,進行了不同個數塊的實驗,并且對比了MoE的方法,訓練損失如下,MoE方法的損失下降程度與添加四個塊相當

圖片

代碼和法律(16.7B)領域數據下進行增量預訓練,在通用任務以及領域任務上比較不同個數塊之間的差異,同時比較擴展塊全部添加到模型底部或頂部之間的差別,如下所示。可以發現塊個數為8時效果最佳,并且不能直接將擴展塊全部堆積在頭部或尾部需要分開插入

圖片

寫在最后

該方法主要通過增加恒定塊擴展模型層數,使模型在增量訓練過程中僅訓練新增層、凍結原始層,保持模型原有能力,防止模型出現災難性遺忘現象。

但有兩點存疑:

  • 目前來說mistral要好于llama,為啥不用mistral進行實驗

  • 不用恒定塊,性能會差多少

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/42029.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/42029.shtml
英文地址,請注明出處:http://en.pswp.cn/web/42029.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

G1 和 CMS

1、CMS CMS(Concurrent Mark Sweep,并發標記清除,是為了解決早期垃圾收集器在執行垃圾回收時導致應用程序暫停時間過長的問題而設計的。 CMS的工作流程主要包括以下幾個階段: 初始標記(Initial Mark)&…

一體化運維監控平臺:賦能各行業用戶運維升級

在當今數字化轉型的大潮中,企業IT系統的復雜性和規模不斷攀升,對運維團隊提出了前所未有的挑戰。如何高效、精準地監控和管理IT基礎設施,確保業務連續性和穩定性,成為所有企業關注的焦點。美信,自2007年成立以來&#…

el-scrollbar實現自動滾動到底部(AI聊天)

目錄 項目背景 實現步驟 實現代碼 完整示例代碼 項目背景 chatGPT聊天消息展示滾動面板,每次用戶輸入提問內容或者ai進行流式回答時需要不斷的滾動到底部確保展示最新的消息。 實現步驟 采用element ui 的el-scrollbar作為聊天消息展示組件。 通過操作dom來實…

端、邊、云三級算力網絡

目錄 端、邊、云三級算力網絡 NPU Arm架構 OpenStack kubernetes k3s輕量級Kubernetes kubernetes和docker區別 DCI(Data Center Interconnect) SD/WAN TF 端、邊、云三級算力網絡 算力網絡從傳統云網融合的角度出發,結合 邊緣計算、網絡云化以及智能控制的優勢,通…

Qt開發 | Qt創建線程 | Qt并發-QtConcurrent

文章目錄 一、Qt創建線程的三種方法二、Qt并發:QtConcurrent介紹三、QtConcurrent run參數說明四、獲取QtConcurrent的返回值五、C其他線程技術介紹 一、Qt創建線程的三種方法 以下是Qt創建線程的三種方法: 方法一:派生于QThread 派生于QThre…

理解算法復雜度:空間復雜度詳解

引言 在計算機科學中,算法復雜度是衡量算法效率的重要指標。時間復雜度和空間復雜度是算法復雜度的兩個主要方面。在這篇博客中,我們將深入探討空間復雜度,了解其定義、常見類型以及如何進行分析。空間復雜度是衡量算法在執行過程中所需內存…

ceph mgr [errno 39] RBD image has snapshots (error deleting image from trash)

ceph mgr 報錯 debug 2024-07-08T09:25:56.512+0000 7f9c63bd2700 0 [rbd_support INFO root] execute_task: task={"sequence": 3, "id": "260b9fee-d567-4301-b7eb-b1fe1b037413", "message": "Removing image replicapool/8…

昇思25天學習打卡營第19天|Diffusion擴散模型

學AI還能贏獎品?每天30分鐘,25天打通AI任督二脈 (qq.com) Diffusion擴散模型 本文基于Hugging Face:The Annotated Diffusion Model一文翻譯遷移而來,同時參考了由淺入深了解Diffusion Model一文。 本教程在Jupyter Notebook上成…

python庫 - missingno

missingno 是一個用于可視化和分析數據集中缺失值的 Python 庫。它提供了一系列簡單而強大的工具,幫助用戶直觀地理解數據中的缺失模式,從而更好地進行數據清洗和預處理。missingno 庫特別適用于數據分析和數據科學項目,尤其是在處理缺失數據…

昇思MindSpore學習筆記5-02生成式--RNN實現情感分類

摘要: 記錄MindSpore AI框架使用RNN網絡對自然語言進行情感分類的過程、步驟和方法。 包括環境準備、下載數據集、數據集加載和預處理、構建模型、模型訓練、模型測試等。 一、概念 情感分類。 RNN網絡模型 實現效果: 輸入: This film is terrible 正…

放大鏡案例

放大鏡 <!DOCTYPE html> <html lang"zh-cn"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>商品放大鏡</title><link rel&qu…

如何使用allure生成測試報告

第一步下載安裝JDK1.8&#xff0c;參考鏈接JDK1.8下載、安裝和環境配置教程-CSDN博客 第二步配置allure環境&#xff0c;參考鏈接allure的安裝和使用(windows環境)_allure windows-CSDN博客 第三步&#xff1a; 第四步&#xff1a; pytest 查看目前運行的測試用例有無錯誤 …

如何使用 pytorch 創建一個神經網絡

我已發布在&#xff1a;如何使用 pytorch 創建一個神經網絡 SapientialM.Github.io 構建神經網絡 1 導入所需包 import os import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets, transforms2 檢查GPU是否可用 dev…

ffmpeg濾鏡創建過程

1、創建一個濾鏡圖 AVFilterGraph *filter_graph avfilter_graph_alloc(); 2、創建濾鏡的輸入和輸出 AVFilterInOut *inputs avfilter_inout_alloc(); AVFilterInOut *outputs avfilter_inout_alloc(); 3、每個濾鏡創建上下文 AVFilterContext *filter1_ctx avfilter_…

Yolov10訓練,轉化onnx,推理

yolov10對于大目標的效果好&#xff0c;小目標不好 一、如果你訓練過yolov5&#xff0c;yolov8&#xff0c;的話那么你可以直接用之前的環境就行 目錄 一、如果你訓練過yolov5&#xff0c;yolov8&#xff0c;的話那么你可以直接用之前的環境就行 二、配置好后就可以配置文件…

android webview 遠程調試

打開遠程調試選項 MainActivity super.onCreate(savedInstanceState);// enable Cordova apps to be started in the backgroundBundle extras getIntent().getExtras();if (extras ! null && extras.getBoolean("cdvStartInBackground", false)) {moveT…

前端JS特效第24集:jquery css3實現瀑布流照片墻特效

jquery css3實現瀑布流照片墻特效&#xff0c;先來看看效果&#xff1a; 部分核心的代碼如下(全部代碼在文章末尾)&#xff1a; <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8" /> <title>jquerycss3實現瀑…

Nginx:負載均衡小專題

運維專題 Nginx&#xff1a;負載均衡小專題 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/…

在Conda環境中高效使用Kubernetes:跨平臺容器化實踐指南

摘要 Conda 是一個流行的跨平臺包和環境管理器&#xff0c;廣泛用于Python社區。而 Kubernetes 是一個開源的容器編排系統&#xff0c;用于自動化部署、擴展和管理容器化應用程序。本文將探討如何在 Conda 環境中使用 Kubernetes&#xff0c;包括設置 Conda 環境、容器化應用程…

【專項刷題】— 位運算

常見類型介紹&#xff1a; & &#xff1a;有 0 就是 0 | &#xff1a;有 1 就是 1 ^ &#xff1a;相同為 0 &#xff0c;相異為 1 或者 無進位相加給定一個數確定它的二進制位的第x個數是0還是1&#xff1a;將一個數的二進制的第x位改成1&#xff1a;將一個數的二進制的第x…