Transformers 加速的一些常用技巧

Transformers 是一個強大的架構,但模型因其采用的自注意力機制,雖然能夠有效地處理序列數據并捕獲長距離依賴關系,但同時也容易導致在訓練過程中出現OOM(Out of Memory,內存不足)或者達到GPU的運行時限制。

主要是因為

  1. 參數數量龐大:Transformer模型通常包含大量的參數,尤其是在模型層面進行擴展時(例如,增加層數或頭數)。這些參數需要大量的內存來存儲權重和梯度。
  2. 自注意力計算:自注意力機制需要對輸入序列的每個元素與其他所有元素計算其相互關系,導致計算復雜度和內存需求隨著輸入長度的增加而顯著增加。對于非常長的序列,這一點尤其突出。
  3. 激活和中間狀態存儲:在訓練過程中,需要存儲前向傳播中的中間激活狀態,以便于反向傳播時使用。這增加了額外的內存負擔。

為了解決這些問題,我們今天來總結以下一些常用的加速策略

固定長度填充

在處理文本數據時,由于文本序列的長度可能各不相同,但許多機器學習模型(尤其是基于Transformer的模型)需要輸入數據具有固定的尺寸,因此需要對文本序列進行固定長度填充(padding)。

在使用Transformer模型時,填充部分不應影響到模型的學習。因此通常需要使用注意力掩碼(attention mask)來指示模型在自注意力計算時忽略這些填充位置。通過這種固定長度填充和相應的處理方法,可以使得基于Transformer的模型能夠有效地處理不同長度的序列數據。在實際應用中,這種方法是處理文本輸入的常見策略。

 def fixed_pad_sequences(sequences, max_length, padding_value=0):padded_sequences = []for sequence in sequences:if len(sequence) >= max_length:padded_sequence = sequence[:max_length]  # Trim the sequence if it exceeds max_lengthelse:padding = [padding_value] * (max_length - len(sequence))  # Calculate paddingpadded_sequence = sequence + padding  # Pad the sequencepadded_sequences.append(padded_sequence)return padded_sequences

這種方式會將所有的序列填充成一個長度,這樣雖然長度相同了,但是因為序列的實際大小本來就不同,同一批次很可能出現有很多填充的情況,所以就出現了動態填充策略。

動態填充是在每個批處理中動態填充輸入序列到最大長度。與固定長度填充不同,在固定長度填充中,所有序列都被填充以匹配整個數據集中最長序列的長度,動態填充根據該批中最長序列的長度單獨填充每個批中的序列。

這樣雖然每個批次的長度是不同的,但是批次內部的長度是相同的,可以加快處理速度。

 def pad_sequences_dynamic(sequences, padding_value=0):max_length = max(len(seq) for seq in sequences)  # Find the maximum length in the sequencespadded_sequences = []for sequence in sequences:padding = [padding_value] * (max_length - len(sequence))  # Calculate paddingpadded_sequence = sequence + padding  # Pad the sequencepadded_sequences.append(padded_sequence)return padded_sequences

等長匹配

等長匹配是在訓練或推理過程中將長度相近的序列分組成批處理的過程。等長匹配通過基于序列長度將數據集劃分為桶,然后從這些桶中采樣批次來實現的。

從上圖可以看到,通過等長匹配的策略,減少了填充量,這樣也可以加速計算

 def uniform_length_batching(sequences, batch_size, padding_value=0):# Sort sequences based on their lengthssequences.sort(key=len)# Divide sequences into buckets based on lengthbuckets = [sequences[i:i+batch_size] for i in range(0, len(sequences), batch_size)]# Pad sequences within each bucket to the length of the longest sequence in the bucketpadded_batches = []for bucket in buckets:max_length = len(bucket[-1])  # Get the length of the longest sequence in the bucketpadded_bucket = []for sequence in bucket:padding = [padding_value] * (max_length - len(sequence))  # Calculate paddingpadded_sequence = sequence + padding  # Pad the sequencepadded_bucket.append(padded_sequence)padded_batches.append(padded_bucket)return padded_batches

自動混合精度

自動混合精度(AMP)是一種通過使用單精度(float32)和半精度(float16)算法的組合來加速深度學習模型訓練的技術。它利用了現代gpu的功能,與float32相比,使用float16數據類型可以更快地執行計算,同時使用更少的內存。

 import torchfrom torch.cuda.amp import autocast, GradScaler# Define your modelmodel = YourModel()# Define optimizer and loss functionoptimizer = torch.optim.Adam(model.parameters(), lr=1e-3)criterion = torch.nn.CrossEntropyLoss()# Create a GradScaler object for gradient scalingscaler = GradScaler()# Inside the training loopfor inputs, targets in dataloader:# Clear previous gradientsoptimizer.zero_grad()# Cast inputs and targets to the appropriate deviceinputs, targets = inputs.to(device), targets.to(device)# Enable autocasting for forward passwith autocast():# Forward passoutputs = model(inputs)loss = criterion(outputs, targets)# Backward pass# Scale the loss valuescaler.scale(loss).backward()# Update model parametersscaler.step(optimizer)# Update the scale for next iterationscaler.update()

AMP在訓練過程中動態調整計算精度,允許模型在大多數計算中使用float16,同時自動將某些計算提升為float32,以防止下流或溢出等數值不穩定問題。

Fp16 vs Fp32

雙精度(FP64)消耗64位。符號值為1位,指數值為11位,有效精度為52位。

單精度(FP32)消耗32位。符號值為1位,指數值為8位,有效精度為23位。

半精度(FP16)消耗16位。符號值為1位,指數值為5位,有效精度為10位。

所以Fp16可以提高內存節省,并可以大大提高模型訓練的速度。考慮到Fp16的優勢和它在模型使用方面的主導區域,它非常適合推理任務。但是fp16會產生數值精度的損失,導致計算或存儲的值不準確,考慮到這些值的精度至關重要。

另外就是這種優化師針對于分類任務的,對于回歸這種需要精確數值的任務Fp16的表現并不好。

總結

以上這些方法,可以在一定程度上緩解內存不足和計算資源的限制,但是對于大型的模型我們還是需要一個強大的GPU。

https://avoid.overfit.cn/post/7240bee210cd408a90ca04279830040e

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/11200.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/11200.shtml
英文地址,請注明出處:http://en.pswp.cn/web/11200.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

AI大模型探索之路-訓練篇22: ChatGLM3微調實戰-從原理到應用的LoRA技術全解

系列篇章💥 AI大模型探索之路-訓練篇1:大語言模型微調基礎認知 AI大模型探索之路-訓練篇2:大語言模型預訓練基礎認知 AI大模型探索之路-訓練篇3:大語言模型全景解讀 AI大模型探索之路-訓練篇4:大語言模型訓練數據集概…

MPLAB X IDE編譯attiny1616工程報錯卻無報錯信息

MPLAB X IDE(XC-8編譯器)編譯報錯,無具體錯誤內容,僅顯示需要xc-8 pro的警告。 內存占用率顯示為81%,未超標。 原因:軟件使用了microchip的bootloader功能。應用程序起始地址(也是bootloader結束地址)設置錯…

社交巨頭:探索Facebook的震撼力量

Facebook作為社交媒體領域的巨頭,不僅在數字化社會中占據著重要地位,更是影響了人們的生活、工作和社交方式。本文將深入探索Facebook的震撼力量,從多個角度解讀其在當今社會中的重要性和影響。 1. 全球用戶覆蓋的壯觀規模 Facebook作為全球…

軟件定義汽車七大典型應用場景

隨著軟件定義汽車典型應用場景的落地,用戶將明顯體驗到汽車從交通工具向智能移動終端的轉變。幾十年前主要用高性能的底盤操穩與動力系統定義一臺好車,幾年前主要用智能化系統與智能交互滿足終端用戶的用車體驗,未來將調度全車傳感器與數據驅…

c 數組遍歷

#include <stdio.h> #include <stdlib.h> int main() { printf(“指針數組練習&#xff01;&#xff01;&#xff01;\n”); /* 數組名就是數組的首地址 數組存在一段連續的內存空間中 */ double score[] {60, 70, 80, 90, 100}; double *ptr_score; i…

docker安裝時報錯:Error: Nothing to do

安裝docker時報以下錯誤 解決方法&#xff1a; 1.下載關于docker的相關依賴環境 yum -y install yum-utils device-mapper-persistent-data lvm22.設置下載Docker的鏡像源 yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo3…

FMEA存在的五個主要不足及改進措施——FMEA軟件

免費試用FMEA軟件-免費版-SunFMEA 在制造業和產品設計領域&#xff0c;失效模式與影響分析&#xff08;Failure Modes and Effects Analysis&#xff0c;簡稱FMEA&#xff09;被廣泛運用&#xff0c;用于預防潛在的設計或制造缺陷。然而&#xff0c;盡管FMEA在風險管理方面發揮…

開發者集結號:大灣區 Open Source Day 邀您共探技術前沿

開源技術正以其開放、協作的特性&#xff0c;引領著軟件開發的新潮流&#xff0c;是推動社會進步的重要力量。作為開發者&#xff0c;您是否渴望深入了解開源項目的前沿動態&#xff1f;由ALC深圳與2024中國互聯網發展創新與投資大賽聯合舉辦、FISCO金鏈盟深度參與的大灣區 Ope…

MySQL————創建存儲過程函數

存儲過程使用大綱 有參數傳遞 delimiter $$ 聲明一個名稱為get_student_introduce create procedure add_student_infor( in p_userName VARCHAR(20),in p_phone VARCHAR(11),in p_sex char(2),in p_introduce VARCHAR(255)) 開始操作 BEGIN 撰寫真正在操作DMLDQL都行 INSE…

CSS---復合選擇器、元素顯示模式和背景(三)

一、CSS的復合選擇器 1.1 什么是復合選擇器 在CSS中&#xff0c;可以根據選擇器的類型把選擇器分為基礎選擇器和復合選擇器&#xff0c;復合選擇器是建立在基礎選擇器之上&#xff0c;對基本選擇器進行組合形成的。 復合選擇器是由兩個或多個基礎選擇器連寫組成&#xff0c;它…

SpringBoot3和SpringBoot2分別整合knife4j(openApi)

文章目錄 一、SpringBoot2進行整合knife4j1.1 導入依賴1.2 配置knife4j 配置文件1.3 可以在接口上配置 注解進行信息的配置 二、SpringBoot3 整合kinfe4j(openApi)2.1 導入依賴2.2 yaml配置文件2.3 swagger初始化配置2.4 創建接口 一、SpringBoot2進行整合knife4j 1.1 導入依賴…

【云原生】kubernetes核心組件

引言&#xff1a; Kubernetes 是為運行分布式集群而建立的&#xff0c;分布式系統的本質使得網絡成為 Kubernetes 的核心和必要組成部分&#xff0c;了解 Kubernetes 網絡模型可以使你能夠正確運行、監控和排查應用程序故障。 一、Kubernetes的核心組件 1.1、Master組件 1.1.…

基于Springboot+Vue的Java項目-農產品直賣平臺系統開發實戰(附演示視頻+源碼+LW)

大家好&#xff01;我是程序員一帆&#xff0c;感謝您閱讀本文&#xff0c;歡迎一鍵三連哦。 &#x1f49e;當前專欄&#xff1a;Java畢業設計 精彩專欄推薦&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python畢業設計 &am…

Kubernetes之Headless Services

Kubernetes中的Headless Services&#xff08;無頭服務&#xff09;是一種特殊類型的服務&#xff08;Service&#xff09;定義&#xff0c;它不提供傳統意義上的負載均衡和集群IP地址分配。在無頭服務中&#xff0c;spec.clusterIP 字段被顯式設置為None &#xff0c;Kubernet…

可道云teamOS企業網盤實用插件介紹:實時在線流程圖編輯與分享,用在線流程圖打造數字化工作流程

在使用企業網盤用于日常辦公的情況下&#xff0c;有一些實用的在線小工具能為團隊效率和協作帶來一定的提升。 今天要給大家介紹的可道云teamOS的在線畫流程圖&#xff0c;是很值得介紹的一個在線工具。 在線流程圖&#xff1a;直觀展示&#xff0c;高效便捷 以往我們想要梳理…

FANUC機器人單軸零點標定時提示無法執行零點標定,由于重力補償已啟用,所有機器人軸的脈沖計數必須有效

FANUC機器人單軸零點標定時提示無法執行零點標定,由于重力補償已啟用,所有機器人軸的脈沖計數必須有效 首先,機器人由于長時間斷電未使用,6個軸的編碼器數據全部丟失,上電后報警SRVO-062, 有關SRVO-062故障報警的相關內容可參考以下鏈接: FANUC機器人SRVO-062報警原因分…

LeetCode 2391. 收集垃圾的最少總時間

Problem: 2391. 收集垃圾的最少總時間 問題分解 我們將這個問題分解為以下幾個小問題&#xff1a; 計算每種垃圾&#xff08;金屬、紙、玻璃&#xff09;在每個房子中的數量。確定每種垃圾車最后到達的房子。計算每種垃圾車行駛的總時間。計算每種垃圾車收拾垃圾的總時間。返…

SQLite 語法大全

SQLite EXPLAIN 語句&#xff1a; EXPLAIN INSERT statement...; or EXPLAIN QUERY PLAN SELECT statement...; SQLite GLOB 子句&#xff1a; SELECT column1, column2....columnN FROM table_name WHERE column_name GLOB { PATTERN }; SQLite GROUP BY 子句&#xff1…

journalctl參數詳解

journalctl 是 Systemd 日志管理工具&#xff0c;用于查看、查詢和管理 Systemd 系統日志。 #-x: 詳細模式&#xff08;Verbose&#xff09;。這個選項會使 journalctl 輸出完整的日志消息&#xff0c;包括其原始結構&#xff0c;如嵌套的JSON消息、未展開的環境變量等。這對于…