LLaMA長度外推高性價比trick:線性插值法及相關改進源碼閱讀及相關記錄

前言

最近,開源了可商用的llama2,支持長度相比llama1的1024,拓展到了4096長度,然而,相比GPT-4、Claude-2等支持的長度,llama的長度外推顯得尤為重要,本文記錄了三種網絡開源的RoPE改進方式及相關源碼的閱讀。

關于長度外推性:https://kexue.fm/archives/9431

關于RoPE:https://kexue.fm/archives/8265

1、線性插值法

論文:EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

鏈接:https://arxiv.org/pdf/2306.15595.pdf

思想:不進行長度外推,而是直接縮小位置索引。即:將4096的位置編碼通過線性插值法壓縮到2048內,這樣只需在少量的4096長度的數據上繼續預訓練,便可達到不錯的效果。

在這里插入圖片描述

源碼閱讀(附注釋)

class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):super().__init__()# 相比RoPE增加scale參數self.scale = scale# inv_freq為基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddings# 構建max_seq_len_cached大小的張量tt = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)# 張量t歸一化,RoPE沒有這一步t /= self.scale# einsum計算頻率矩陣# 'i, j->i j’表示分別輸入尺寸為[i]、[j]的向量,做笛卡爾運算得到尺寸為[i, j]的矩陣。freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# 在-1維做一次拷貝、拼接emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()# 注冊為模型的緩沖區cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.# seq_len為序列長度,seq_len大于max_seq_len_cached,則重新計算頻率矩陣,并更新cos_cached和sin_cached的緩沖區if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)t /= self.scalefreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 長度裁剪:返回cos_cached和sin_cached中與seq_len(序列長度)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

線性插值法的相關實驗效果:https://lmsys.org/blog/2023-06-29-longchat/

2、NTK插值法

NTK插值改進llama中使用的RoPE插值方法,同樣,對于RoPE代碼改動更小,其他地方與線性插值法實現一致。

reddit原帖:NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation

鏈接:https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=58346

源碼閱讀:

class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):super().__init__()# 與線性插值法相比,實現更簡單,alpha僅用來改變basebase = base * alpha ** (dim / (dim-2))inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

3、動態插值法

動態插值法又是對NTK插值法和線性插值法的改進,可以看作是上述兩者的一種結合思想,旨在減少困惑度損失并實現更大的縮放。

reddit原帖:Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

鏈接:https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

源碼閱讀

class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):super().__init__()# 是否開啟NTK(Neural Tangent Kernel)self.ntk = ntkself.base = baseself.dim = dimself.max_position_embeddings = max_position_embeddings# inv_freq為基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# emb:[max_seq_len_cached, dim]emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lenif self.ntk:base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))# 計算新的inv_freqinv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))self.register_buffer("inv_freq", inv_freq)t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)if not self.ntk:# 縮放t *= self.max_position_embeddings / seq_len# 得到新的頻率矩陣freqsfreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# freqs與自身拼接得到新的embemb = torch.cat((freqs, freqs), dim=-1).to(x.device)# 注冊為模型的緩沖區cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 長度裁剪return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

網友對于困惑度的實驗并取得了一定的效果:https://github.com/turboderp/exllama/pull/118

總結

本文介紹了llama通過線性插值法及相關改進方案進行長度外推的trcik,并對相關源碼閱讀及網絡資源進行記錄,個人粗淺認為,相比LongLLaMA,基于線性插值法+Finetune的方式,是一種高性價比的長度外推方案。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/38287.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/38287.shtml
英文地址,請注明出處:http://en.pswp.cn/news/38287.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Vue-打印組件頁面

場景: 需要將頁面的局部信息打印出來&#xff0c;只在前端實現&#xff0c;不要占用后端的資源。經過百度經驗&#xff0c;決定使用 print-js和html2canvas組件。 1. 下載包 npm install print-js --save npm install --save html2canvas 2. 組件內引用 <script>impo…

C語言之數組指針和指針數組

C語言之數組指針和指針數組 一、含義二、定義2.1 指針數組2.2 數組指針 三、使用3.1 指針數組在參數傳遞時的使用3.1.1 指針數組的排序3.2 數組指針在參數傳遞時的使用 一、含義 指針數組&#xff1a;顧名思義&#xff0c;其為一個數組&#xff0c;數組里面存放著多個指針&…

C#生成隨機驗證碼

以下是一個簡單的C#驗證碼示例&#xff1a; private void GenerateCaptcha() {// 生成隨機字符串string chars "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";Random random new Random();string captchaString new string(Enumerable.Repe…

TPAMI, 2023 | 用壓縮隱逆向神經網絡進行高精度稀疏雷達成像

CoIR: Compressive Implicit Radar | IEEE TPAMI, 2023 | 用壓縮隱逆向神經網絡進行高精度稀疏雷達成像 注1:本文系“無線感知論文速遞”系列之一,致力于簡潔清晰完整地介紹、解讀無線感知領域最新的頂會/頂刊論文(包括但不限于Nature/Science及其子刊;MobiCom, Sigcom, MobiSy…

Java【算法 04】HTTP的認證方式之DIGEST認證詳細流程說明及舉例

HTTP的認證方式之DIGEST 1.是什么2.認值流程2.1 客戶端發送請求2.2 服務器返回質詢信息2.2.1 質詢參數2.2.2 質詢舉例 2.3 客戶端生成響應2.4 服務器驗證響應2.5 服務器返回響應 3.算法3.1 SHA-2563.1.1 Response3.1.2 A13.1.3 A2 3.2 MD53.2.1 Request-Digest3.2.2 A13.2.3 A2…

CSS3 中新增了哪些常見的特性?

聚沙成塔每天進步一點點 ? 專欄簡介? 圓角&#xff08;Border Radius&#xff09;? 漸變&#xff08;Gradients&#xff09;? 陰影&#xff08;Box Shadow&#xff09;? 文本陰影&#xff08;Text Shadow&#xff09;? 透明度&#xff08;Opacity&#xff09;? 過渡&…

Spring boot與Spring cloud 之間的關系

Spring boot與Spring cloud 之間的關系 Spring boot 是 Spring 的一套快速配置腳手架&#xff0c;可以基于spring boot 快速開發單個微服務&#xff0c;Spring Boot&#xff0c;看名字就知道是Spring的引導&#xff0c;就是用于啟動Spring的&#xff0c;使得Spring的學習和使用…

MATLAB中xlsread函數用法

目錄 語法 說明 示例 將工作表讀取到數值矩陣 讀取元胞的范圍 讀取列 請求數值、文本和原始數據 對工作表執行函數 請求自定義輸出 局限性 xlsread函數的功能是讀取Microsoft Excel 電子表格文件 語法 num xlsread(filename) num xlsread(filename,sheet) num x…

Nacos和GateWay路由轉發NotFoundException: 503 SERVICE_UNAVAILABLE “Unable to find

問題再現&#xff1a; 2023-08-15 16:51:16,151 DEBUG [reactor-http-nio-2][CompositeLog.java:147] - [dc73b32c-1] Encoding [{timestampTue Aug 15 16:51:16 CST 2023, path/content/course/list, status503, errorService Unavai (truncated)...] 2023-08-15 16:51:16,17…

leetcode27—移除元素

思路&#xff1a; 參考26題目雙指針的思想&#xff0c;只不過這道題不是快慢指針。 看到示例里面數組是無序的&#xff0c;也就是說后面的元素也是可能跟給定 val值相等的&#xff0c;那么怎么處理呢。就想到了從前往后遍歷&#xff0c;如果left對應的元素 val時&#xff0c…

汽車制造業上下游協作時 外發數據如何防泄露?

數據文件是制造業企業的核心競爭力&#xff0c;一旦發生數據外泄&#xff0c;就會給企業造成經濟損失&#xff0c;嚴重的&#xff0c;可能會帶來知識產權剽竊損害、名譽傷害等。汽車制造業&#xff0c;會涉及到重要的汽車設計圖紙&#xff0c;像小米發送汽車設計圖紙外泄事件并…

[足式機器人]Part5 機械設計 Ch00/01 緒論+機器結構組成與連接 ——【課程筆記】

本文僅供學習使用 本文參考&#xff1a; 《機械設計》 王德倫 馬雅麗課件與日常作業可登錄網址 http://edu.bell-lab.com/manage/#/login&#xff0c;選擇觀摩登錄&#xff0c;查看2023機械設計2。 機械設計-Ch00Ch01——緒論機器結構組成與連接 Ch00-緒論0.1 何為機械設計——…

12.Eclipse導入Javaweb項目

同事復制一份他的項目給我ekp.rar (懶得從SVN上拉取代碼了)放在workspace1目錄下 新建一個文件夾 workspace2&#xff0c;Eclipse切換到workspace2工作空間 選擇Import導入 選擇導入的項目(這里是放到workspace1里面) 拷貝一份到workspace2里面 例子 所有不是在自己電腦上開發…

可白嫖的4家免費CDN,并測試其網絡加速情況(2023版)

網站加載速度優化過程中&#xff0c;不可避免的會用上CDN來加速資源的請求速度。但是市面上的CDN資源幾乎都是要收費的&#xff0c;而且價格還不便宜&#xff0c;對于小公司站長來講&#xff0c;這將是一筆不小的開銷。不過還是有一些良心公司給我們提供了免費的資源&#xff0…

ZooKeeper的基本概念

集群角色 通常在分布式系統中&#xff0c;構成一個集群的每一臺機器都有自己的角色&#xff0c;最典型的集群模式就是Master/Slave模式(主備模式)。在這種模式中&#xff0c;我們把能夠處理所有寫操作的機器稱為Master機器&#xff0c;把所有通過異步復制方式獲取最新數據&…

Redis_億級訪問量數據處理

11. 億級訪問量數據處理 11.1 場景表述 手機APP用戶登錄信息&#xff0c;一天用戶登錄ID或設備ID電商或者美團平臺&#xff0c;一個商品對應的評論文章對應的評論APP上有打卡信息網站上訪問量統計統計新增用戶第二天還留存商品評論的排序月活統計統計獨立訪客(Unique Vistito…

【BEV】3D視覺 PRELIMINARY

這里的知識來自于論文 Delving into the Devils of Bird’s-eye-view Perception: A Review, Evaluation and Recipe 的 Appendix B.1 部分來自 這篇文章 從透視圖轉向鳥瞰圖。&#xff08;Xw、Yw、Zw&#xff09;、&#xff08;Xc、Yc、Zc&#xff09;表示世界World坐標和相…

Android學習之路(4) UI控件之Button (按鈕)與 ImageButton (圖像按鈕)

本節引言&#xff1a; 今天給大家介紹的Android基本控件中的兩個按鈕控件&#xff0c;Button普通按鈕和ImageButton圖像按鈕&#xff1b; 其實ImageButton和Button的用法基本類似&#xff0c;至于與圖片相關的則和后面ImageView相同&#xff0c;所以本節 只對Button進行講解&am…

vue自定義穿梭框支持遠程滾動加載

分享-2023年資深前端進階&#xff1a;前端登頂之巔-最全面的前端知識點梳理總結&#xff0c;前端之巔 *分享一個使用比較久的&#x1fa9c; 技術框架公司的選型(老項目)&#xff1a;vue2 iview-ui 方案的實現思路是共性的&#xff0c;展現UI樣式需要你們自定義進行更改&#…

【注解使用】使用@Autowired后提示:Field injection is not recommended(Spring團隊不推薦使用Field注入)

問題發生場景&#xff1a; 在使用 IDEA 開發 SpringBoot 項目時&#xff0c;在 Controller 類中使用注解 Autowired 注入一個依賴出現了警告提示&#xff0c;查看其他使用該注解的地方同樣出現了警告提示。這是怎么回事&#xff1f;由于先去使用了SpringBoot并沒有對Spring進行…