手搓多模態-08 主模型的搭建(上)

前情回顧

在之前的章節我們已經構建好了視覺編碼器,預處理模塊,以及gemma模型的頂層。gemma模型的頂層,主要是構建圖中圈出的輸入,它把視覺編碼器里每個圖像patch的編碼維度對齊到自然語言token的嵌入維度,并組裝成了一個大的輸入向量。同時在模型的頂層,我們準備好了位置id 以及attention mask,用來在后面的模型層計算旋轉位置編碼和注意力得分矩陣。接下來,我們要開始構建gemma模型的架構了。

頂層模型 GemmaForCausalLM

還記得嗎,在之前的paligemma模型的頂層,我們有一個GemmaForCausalLM,然后我們通過下面的代碼把輸入傳入了語言模型:

self.language_model = GemmaForCausalLM(config.text_config)
		outputs = self.language_model(
			inputs_embeds = input_embeds,
			position_ids = position_ids,
			attention_mask = attention_mask,
			kv_cache = kv_cache,
			**kwargs
		)

現在我們首先要實現這個GemmaForCausalLM。

一般模型的上層是對整個模型邏輯的簡單封裝,故這里GemmaForCausalLM的作用很簡單,它僅僅把上下文編碼后的注意力嵌入通過一個MLP轉換為不同token的輸出概率,也就是logits,然后返回給上層,從而讓上層根據概率分布來采樣下一個要輸出的token是什么。

先給出代碼:

class GemmaForCausalLM(nn.Module): ## 匹配
	def __init__(self,config:GemmaConfig): ##CasualLM實際上是Transformer模型加一個投影層,即將嵌入轉換為對數概率
		super().__init__()
		self.config = config
		self.model = GemmaModel(config)
		self.vocab_size = config.vocab_size
		self.lm_head = nn.Linear(config.hidden_size,config.vocab_size,bias=False)	def get_input_embeddings(self): ##這里返回的是模型對象本身
		return self.model.embed_tokens	def tie_weights(self):
		self.lm_head.weight = self.model.embed_tokens.weight	def forward(
		self,
		attention_mask: Optional[torch.Tensor] = None,
		inputs_embeds: Optional[torch.FloatTensor] = None,
		kv_cache: Optional[KVCache] = None,
		position_ids: Optional[torch.Tensor] = None
	):
		'''
		input: [Batch_size, Seq_len, Hidden_size]
		output: [Batch_size, Seq_len, Vocab_size]
		'''
		## [Batch_size, Seq_len, Hidden_size]
		outputs = self.model(
			attention_mask = attention_mask,
			inputs_embeds = inputs_embeds,
			kv_cache = kv_cache,
			position_ids = position_ids
		)		hidden_states = outputs
		logits = self.lm_head(hidden_states) #lm_head負責將hidden_states映射到vocab_size維度的向量,即logits
		logits = logits.float()		return_data = {
			"logits": logits 
		}
		if kv_cache is not None:
			return_data["kv_cache"] = kv_cache ##這里kv cache是要傳遞下去的,因為自回歸的邏輯下,后面生成的token的注意力計算要能夠通過kv cache來看到之前的token的kv		return return_data

以上便是頂層模型的前向傳遞過程:

  • 就是通過 GemmaModel 生成的注意力嵌入來計算logits
  • 注意:由于我們在推理過程中,后續的token計算要用到之前的kv,所以kv cache必須在推理的過程中依次傳遞下去,同時也要返回給上層,從而在下一次推理運算的時候有kv cache可以傳入。
  • 我們之前用到了參數捆綁的策略,即token嵌入的模型參數等于嵌入反解碼成logits的模型參數,所以我們提供這兩個函數供上層調用:
def get_input_embeddings(self): ##這里返回的是模型對象本身return self.model.embed_tokensdef tie_weights(self):self.lm_head.weight = self.model.embed_tokens.weight

GemmaModel

GemmaModel里面實際上就是一個注意力塊的序列,就像一個注意力塊數組一樣,而該層需要做的僅僅是將輸入在不同的注意力塊里依次傳遞,并把最后一個注意力塊的輸出返回給上層即可。

class GemmaModel(nn.Module): ## 匹配def __init__(self,config:GemmaConfig):super().__init__()
		self.config = config
		self.hidden_size = config.hidden_size
		self.embed_tokens = nn.Embedding(config.vocab_size,config.hidden_size,padding_idx=config.pad_token_id)
		self.layers = nn.ModuleList([GemmaLayer(config, _) for _ in range(config.num_hidden_layers)])
		self.norm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps) ##Root Mean Square Normalization均方根標準化,該論文表明并不一定要標準化到標準正態分布,而是只要方差為1就可以def forward(
		self,
		attention_mask: Optional[torch.Tensor] = None,
		inputs_embeds: Optional[torch.FloatTensor] = None,
		kv_cache: Optional[KVCache] = None,
		position_ids: Optional[torch.Tensor] = None):#[Batch_size, Seq_len, Hidden_size]
		hidden_states = inputs_embeds
		normalizer = torch.tensor(self.hidden_size ** 0.5,dtype= inputs_embeds.dtype)
		hidden_states = hidden_states * normalizerfor layer in self.layers:
			hidden_states = layer(
				hidden_states = hidden_states,
				attention_mask = attention_mask,
				kv_cache = kv_cache,
				position_ids = position_ids)## 均方根歸一化,不改變shape
		hidden_states = self.norm(hidden_states)return hidden_states

這里我們用一個nn.ModuleList來存儲所有的GemmaLayer,一個GemmaLayer實際上就是一個attention 塊。值得注意的是,在每個attention塊內部我們將會做兩次歸一化,但是每個attention layer的輸出不會做歸一化,為了使得上層的計算能拿到歸一化后的結果,我們在整個list前向傳遞完了之后再補一個normalization的過程:

hidden_states = self.norm(hidden_states)
  • 注意:我們此處用的是RMSNorm,即均方根歸一化,關于這個歸一化與之前的其他歸一化的不同我們會在文末補充一些資料。

有人可能想問,為什么嵌入模型會放到這里:self.embed_tokens

這是因為paligemma的作者是這么實現的,而我們將從huggingface來導入整個模型的參數,所以我們的架構也必須和作者一樣才能正確導入參數,所以我們不得不放在這里。

GemmaLayer

在一個attention塊里面我們有一個多頭注意力層和一個前向傳播網絡,以及兩個歸一化,但我們實際的實現中會把歸一化提前,即add&norm -> attention -> add&norm -> ff。

這也就是為什么上面提到在layer的輸出沒有做歸一化。

代碼如下:

class GemmaLayer(nn.Module): ##匹配def __init__(self,config:GemmaConfig,layer_idx:int): ##layer_idx是當前layer的索引,輔助attention存儲kv_cachesuper().__init__()self.config = configself.layer_idx = layer_idxself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_sizeself.input_layernorm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.mlp = GemmaMLP(config)self.self_attn = GemmaAttention(config,layer_idx)def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,kv_cache: Optional[KVCache] = None,position_ids: Optional[torch.Tensor] = None)-> Tuple[torch.Tensor,Optional[Tuple[torch.FloatTensor,torch.FloatTensor]]]:
		'''input: [Batch_size, Seq_len, Hidden_size]output: [Batch_size, Seq_len, Hidden_size]		'''residual = hidden_stateshidden_states = self.input_layernorm(hidden_states)hidden_states,_ = self.self_attn(hidden_states = hidden_states,attention_mask = attention_mask,kv_cache = kv_cache,position_ids = position_ids)hidden_states = residual + hidden_statesresidual = hidden_stateshidden_states = self.post_attention_layernorm(hidden_states)hidden_states = self.mlp(hidden_states)hidden_states = residual + hidden_statesreturn hidden_states

  • 在這里的兩個歸一化我們也用RMSNorm來進行歸一化,注意除了歸一化,我們還要處理好殘差。
  • 殘差的作用是防止梯度為0導致訓練緩慢。

RMSNorm

在前面的第四章節:手搓多模態-04 歸一化介紹 里面我們介紹了BatchNormalization和LayerNormalization,我們了解到以下信息:

  • 歸一化是為了防止不同模型層的輸入輸出不穩定,分布不均勻導致的訓練速度過慢
  • BN 依賴于batch 的規模,而batch的規模過大會導致訓練速度變相過慢
  • LN 通過對單個樣本的所有特征進行標準化規避了BN的問題,主要做法是對單個樣本的所有特征計算均值和方差,從而將其分布轉換為0-1分布。

RMSNormalization,又稱均方差歸一化,是由論文《Root Mean Square Layer Normalization》提出的,該文章發現,其實分布不穩定的問題和均值沒有關系,主要是方差的問題,所以只需要特征的方差穩定即可,不需要計算均值,這樣可以減少計算的時間,從而加速訓練。

論文提出用均方根來對每個值進行縮放,從而使得方差更小,如圖所示。

其中,a_i 表示縮放前的特征值,RMS(a)表示所有特征值計算出來的均方根,g是一個可學習的參數向量,b是偏置。

在paligemma的實現中,RMSNorm的代碼如下:

class GemmaRMSNorm(nn.Module): ##匹配
	def __init__(self,dim,eps=1e-6): ##dim是hidden_size
		super().__init__()
		self.dim = dim
		self.eps = eps
		self.weight = nn.Parameter(torch.ones(dim))	def _norm(self,x):		return x * torch.rsqrt(x.pow(2).mean(dim = -1,keepdim=True) + self.eps) ##rsqrt表示平方的倒數,self.eps是防止分母為0	def forward(self,x):
		x = self._norm(x)
		output = x * (1.0 + self.weight.float()) ##論文中的可學習參數g
		return output.type_as(x)

其中特征的維度為嵌入的維度大小。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/85269.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/85269.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/85269.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Matlab 角點探測

文章目錄 一、簡介二、實現代碼三、實現效果一、簡介 這里實現一種角點探測功能,其思路仍然是借助圖像的局部梯度信息,實現亞像素精度的角點定位。該功能核心思想是利用角點周圍的局部梯度信息,通過加權最小二乘優化的方式迭代調整角點位置,使定位精度達到亞像素級別。整個…

錯誤監控----比如實現sentry一些思路

錯誤監控 ?、引? 1.為什么需要前端錯誤監控 你的腳本在哪些邊界條件下會報錯? 你的腳本和樣式兼容性如何? 有哪些地區不能正常訪問你的?站? 出現問題之后,你如何快速定位排查,把損失降到最低? 如果你想解…

linux內核調試

1. 前置安裝 1.1 編譯好的內核 參考: https://blog.csdn.net/qq_51950769/article/details/148596916 1.2 編譯busybox BusyBox 是一個非常輕量級的多合一工具箱,常被稱為“Linux 的瑞士軍刀”。 簡單來說: 它把很多常用的 Linux 命令&am…

什么是曲面細分

什么是曲面細分 在CAD格式中,通常使用曲線和數學函數來定義曲面和實體。這些曲面的精確度和光滑度非常適用于制造過程。但是,現代GPU芯片針對由三角形網格體組成的曲面的渲染計算進行了高度優化。通常,實時渲染器和虛幻之類的游戲引擎只能處…

CANFD加速是什么?和CANFD有什么區別?

文章目錄 摘要什么是CANFD加速?CAN FD的基本原理:仲裁階段(Arbitration Phase):數據階段(Data Phase):關鍵特性:優勢:總結摘要 下面的截圖,大家肯定不陌生,在使用CAN設備上位機的時候,已經選擇了CANFD,但還有一個選項是“CANFD加速”,那CANFD加速和不加速有什么…

minio 啟動失敗--Incorrect Usage: flag provided but not defined: -consoleaddress

根據錯誤信息 flag provided but not defined: -consoleaddress,這表明 Minio 服務啟動時使用了未定義的命令行參數 --consoleaddress,導致啟動失敗。這個問題與 Minio 版本兼容性有關。 問題原因 參數名稱變更: Minio 版本 > RELEASE.20…

基于Rust的Polars學習筆記

基于Rust的Polars學習筆記 Polars 學習筆記 Cargo.toml通用配置 [package] name = "rustP" version = "0.1.0" edition = "2024"[dependencies] polars = { version = "0.48.1", features = ["full"]}Quickstart use po…

SpringBoot擴展——定時任務!

定時任務 項目開發中會涉及很多需要定時執行的代碼,如每日凌晨對前一日的數據進行匯總,或者系統緩存的清理、對每日的數據進行分析和總結等需求,這些都是定時任務。單體系統和分布式系統的分布式任務有很大的區別,單體系統就一個…

RTDETRv2 pytorch 官方版自己數據集訓練遇到的問題解決

rtdetrv2 訓練問題遇到的問題。 pip install torch2.0.1 torchvision0.15.2 torchaudio2.0.2 --index-url https://download.pytorch.org/whl/cu117 1 Please make sure torchvision version > 0.15.2 發現自己實際裝的是 torchvison0.15.2cu117 修改_misc.py中修改為…

Linux系統移植⑤:uboot啟動流程詳解-board_init_f執行過程

Linux系統移植⑤:uboot啟動流程詳解-board_init_f執行過程 _main 中會調用 board_init_f 函數。 board_init_f 函數主要有兩個工作: ①初始化一系列外設,比如串口、定時器,或者打印一些消息等。 ②初始化 gd 的各個成員變量&am…

Git命令與代碼倉庫管理

步驟一、完成Gitee碼云上賬號注冊并新建代碼倉庫。 1.1 新建代碼倉庫 1.2 填寫信息并創建 1.3 獲取倉庫地址 https://gitee.com/dog-kidney/2022082206.git 步驟二、建立本地代碼倉庫,并連接到遠程代碼倉庫。 2.1初始化 git init 2.2添加倉庫 git remote add o…

資源占用多,Linux 系統中如何降低 CPU 資源消耗并提升利用率?

在 Linux 系統中降低 CPU 資源消耗并提升利用率,需從系統服務優化、進程管理、資源調度及內核參數調整等多維度入手。以下是適用于各類 Linux 發行版的通用優化方案,涵蓋基礎操作與進階策略: 一、服務與進程優化:減少無效資源占用 1. 關閉冗余系統服務 查看運行中的服務 …

技術與情感交織的一生 (八)

目錄 融合 東西廠公 接風宴 頭痛 “巴巴羅薩” 突擊 推進 助攻 96小時 寒冬 食堂 反攻 消耗 Delphi 西廠 內困 外患 “敦刻爾克” 多線作戰 大撤退 資源 融合 東西廠公 初次來到紙箱廠,是主廠區,感覺很大,相對西面正在…

webuploader分片上傳示例,服務端上傳文件到騰訊云CDN Teo 應用示例

本文環境:php7.3.4 CI3.0框架 一、大概步驟: (1)利用百度的webuploader插件,將大文件分片上傳的自己的服務器 (2)利用騰訊云接口從本服務器上傳到騰訊云 二、詳細代碼: 1、進入…

LeetCode 632.最小區間

你有 k 個 非遞減排列 的整數列表。找到一個 最小 區間&#xff0c;使得 k 個列表中的每個列表至少有一個數包含在其中。 我們定義如果 b-a < d-c 或者在 b-a d-c 時 a < c&#xff0c;則區間 [a,b] 比 [c,d] 小。 示例 1&#xff1a; 輸入&#xff1a;nums [[4,10,…

篇章五 系統性能優化——資源優化——CPU優化(2)

目錄 1.高級并發模式 1.1 工作竊取&#xff08;Work Stealing&#xff09; 1.工作竊取模式 2.ForkJoinPool實現 3.具體例子 1.2 結構化并發&#xff08;Structured Concurrency&#xff09; 1.結構化并發模式 2.Java 19 的 StructuredTaskScope 3.具體例子 1.3 對比與…

《中國電信運營商骨干網:歷史、現狀與未來演進》系列 第四篇:后發先至——中國移動CMNET的快速擴張與IP專網布局

摘要&#xff1a; 本文深入探討中國移動骨干網CMNET (AS9808) 的發展歷程、網絡架構及其與中國電信扁平化策略的差異。同時&#xff0c;解析其為承載高價值業務而構建的IP專用承載網的定位、結構與技術特點。最后&#xff0c;展望中國移動在5G、云計算和算力網絡時代&#xff0…

R情感分析:解碼文本中的情感

基于之前關于文本聚類和文本模型的博客&#xff0c;我們現在可以深入探討一個經典主題 - 情感分析。情感分析通過計算方式識別和分類文本中的情感&#xff0c;幫助理解公眾意見或消費者反饋。 什么是情感分析&#xff1f; 情感分析確定文本背后的情感基調&#xff0c;將其分類…

云徙渠道訂貨系統:賦能企業渠道管理的數字化引擎

在當今商業競爭日益激烈的環境下&#xff0c;企業如何高效管理和優化渠道成為關鍵問題。云徙渠道訂貨系統憑借其強大的數字化能力&#xff0c;為企業提供了全新的渠道管理解決方案&#xff0c;助力企業在復雜多變的市場環境中保持競爭力。 從渠道管理的痛點出發 傳統渠道管理方…

Nacos基礎使用(二):nacos作為配置中心

一、Nacos 配置中心核心屬性 在學習nacos 作為配置中心的使用之前&#xff0c;先看下Nacos 作為配置中心時的三個屬性&#xff0c;即&#xff1a; 命名空間、配置分組、配置集ID&#xff08;習慣稱為配置文件ID&#xff09;&#xff1b;在使用Nacos 作為配置中心 的過程中可以通…