Preface
Recently, the commercially usable Llama 2 was open-sourced, extending the supported context length from Llama 1's 2048 to 4096. However, compared with the context lengths supported by GPT-4, Claude 2 and similar models, length extrapolation for LLaMA remains especially important. This post records three openly shared RoPE modifications and my notes from reading the related source code.
On length extrapolation: https://kexue.fm/archives/9431
On RoPE: https://kexue.fm/archives/8265
1. Linear Interpolation
Paper: Extending Context Window of Large Language Models via Position Interpolation
Link: https://arxiv.org/pdf/2306.15595.pdf
Idea: rather than extrapolating beyond the trained length, directly shrink the position indices. Concretely, the position encodings of a 4096-length sequence are compressed into the 2048 range via linear interpolation, so that continued pre-training on only a small amount of 4096-length data is enough to achieve good results.
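To make the idea concrete, here is a minimal sketch (the values of dim, base and scale are illustrative, not taken from the paper): position interpolation simply divides the position index by the scale factor before the RoPE angles are computed, so a position beyond the training window maps back into the trained range.

import torch

dim, base, scale = 128, 10000, 2.0            # illustrative: head dim, RoPE base, 4096 / 2048
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

m = torch.tensor([4095.0])                    # a position outside the original 2048 window
angles_extrapolated = m * inv_freq            # vanilla RoPE: angles the model never saw during training
angles_interpolated = (m / scale) * inv_freq  # position interpolation: behaves like position 2047.5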
Source code reading (with comments):
import torch


class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
        super().__init__()
        # Compared with vanilla RoPE, an extra scale parameter is added
        self.scale = scale
        # inv_freq is the vector of base frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        # Build a position tensor t of size max_seq_len_cached
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # Scale down the positions; vanilla RoPE does not have this step
        t /= self.scale
        # Compute the frequency matrix with einsum
        # "i,j->ij" takes two vectors of sizes [i] and [j] and computes their outer product, a matrix of size [i, j]
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # Duplicate and concatenate along the last dimension
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        # Register cos_cached and sin_cached as model buffers
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        # If seq_len exceeds max_seq_len_cached, recompute the frequency matrix and refresh the cos_cached / sin_cached buffers
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            t /= self.scale
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        # Slice to seq_len: return the first seq_len positions of cos_cached and sin_cached
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
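For reference, a minimal usage sketch of the class above (the batch size, head count and head size are hypothetical choices, not from the original post):

rope = LlamaLinearScaledRotaryEmbedding(dim=128, max_position_embeddings=2048, scale=2)
q = torch.randn(1, 32, 4096, 128)   # [bs, num_attention_heads, seq_len, head_size], hypothetical shapes
cos, sin = rope(q, seq_len=4096)    # seq_len exceeds the cache, so the scaled tables are rebuilt
print(cos.shape)                    # torch.Size([1, 1, 4096, 128])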
Experimental results for linear interpolation: https://lmsys.org/blog/2023-06-29-longchat/
2. NTK Interpolation
NTK interpolation improves on the RoPE interpolation used in LLaMA; the change it needs in the RoPE code is even smaller, and everything else is implemented the same way as in the linear interpolation version.
Original Reddit post: NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation
Link: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=58346
Source code reading:
class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
        super().__init__()
        # Compared with linear interpolation, the implementation is even simpler: alpha is only used to change the base
        base = base * alpha ** (dim / (dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
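A quick way to see that the NTK variant only stretches the frequency base and leaves everything else untouched (the toy values below are illustrative, not from the post):

import torch

dim, base, alpha = 128, 10000, 4
ntk_base = base * alpha ** (dim / (dim - 2))
inv_freq_vanilla = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq_ntk = 1.0 / (ntk_base ** (torch.arange(0, dim, 2).float() / dim))
# The high-frequency components are almost unchanged, while the lowest frequency is scaled down by roughly 1/alpha,
# i.e. the low frequencies are interpolated and the high ones are (almost) left to extrapolate.
print(inv_freq_ntk[0] / inv_freq_vanilla[0])    # ~1.0
print(inv_freq_ntk[-1] / inv_freq_vanilla[-1])  # ~0.25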
3. Dynamic Interpolation
Dynamic interpolation is a further refinement of NTK interpolation and linear interpolation, and can be viewed as a combination of the two ideas; it aims to reduce the perplexity penalty while allowing a larger scaling factor.
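Roughly speaking, instead of fixing the scale (or alpha) ahead of time, the dynamic variant derives it from the current sequence length whenever the cache is rebuilt. A small sketch of how the effective scaling grows with seq_len, reusing the parameter names from the code below (the concrete values are illustrative):

max_position_embeddings, dim, base, ntk = 2048, 128, 10000, 1
for seq_len in (2048, 4096, 8192, 16384):
    linear_scale = seq_len / max_position_embeddings   # linear mode: positions are divided by this factor
    ntk_base = base * ((ntk * seq_len / max_position_embeddings) - (ntk - 1)) ** (dim / (dim - 2))
    print(seq_len, linear_scale, round(ntk_base))      # the longer the input, the stronger the scaling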
Original Reddit post: Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning
Link: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/
Source code reading:
class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        # Whether to enable NTK (Neural Tangent Kernel) scaling; in forward() this value is also used as the alpha factor
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # inv_freq is the vector of base frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # emb: [max_seq_len_cached, dim]
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:
                # NTK mode: derive a new base from the current seq_len, then compute a new inv_freq
                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim - 2))
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:
                # Linear mode: rescale the positions by max_position_embeddings / seq_len
                t *= self.max_position_embeddings / seq_len
            # Compute the new frequency matrix freqs
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            # Concatenate freqs with itself to get the new emb
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # Register cos_cached and sin_cached as model buffers
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        # Slice to seq_len
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
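And a minimal usage sketch (the shapes are hypothetical); the scaled tables are only rebuilt once the requested seq_len exceeds what is already cached:

rope = LlamaDynamicScaledRotaryEmbedding(dim=128, max_position_embeddings=2048, ntk=2)
q = torch.randn(1, 32, 2048, 128)       # within the original window: the plain RoPE tables are used
cos, sin = rope(q, seq_len=2048)
q_long = torch.randn(1, 32, 6144, 128)  # beyond the window: the base is re-derived from seq_len and the tables rebuilt
cos, sin = rope(q_long, seq_len=6144)
print(cos.shape)                        # torch.Size([1, 1, 6144, 128])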
Community perplexity experiments with encouraging results: https://github.com/turboderp/exllama/pull/118
Summary
This post introduced the tricks used to extend LLaMA's context length via linear interpolation and its refinements, and recorded notes on the related source code and online resources. My own (admittedly superficial) take is that, compared with LongLLaMA, linear interpolation plus fine-tuning is a highly cost-effective approach to length extrapolation.