CLIP模型實現中的其他細節

之前已經完整的拆解了CLIP中所用到的ResNet、ViT和Transformer三個模型(CLIP拆解-CSDN博客),這篇將講解model.py實現中的其他細節。

1.關于ResNet模型中vision_head的設置

ResNet:

vision_heads = vision_width * 32 // 64

ViT:

vision_heads = vision_width // 64

ResNet需要乘32是因為經過前面卷積處理后輸入AttentionPool2d的是width*32,所以計算head的時候要把這個考慮進去。至于這里的64是分為多頭后每一個頭的embed的通道數,ResNet通常取64,ViT-B常取768

2.關于conver_weights

convert_weights() 是為了節省顯存、提高推理速度,將模型中適合的權重轉換為 fp16。

(1)half()的作用 就是把fp32轉為fp16,如果輸入本身是 fp16,那將不進行任何處理。

(2)一些結構不建議轉化為fp16,因為轉化后會不穩定,所以選擇性的處理

    def _convert_weights_to_fp16(l):if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):l.weight.data = l.weight.data.half()if l.bias is not None:l.bias.data = l.bias.data.half()if isinstance(l, nn.MultiheadAttention):for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:tensor = getattr(l, attr)if tensor is not None:tensor.data = tensor.data.half()for name in ["text_projection", "proj"]:if hasattr(l, name):attr = getattr(l, name)if attr is not None:attr.data = attr.data.half()

下面是常見的不建議使用fp16的模塊:

模塊/操作原因說明
LayerNorm / BatchNorm均值/方差運算容易數值下溢,精度敏感
Softmax / LogSoftmax輸出接近 0 或 1,fp16 下舍入誤差大
Sigmoid / Tanh對小輸入不敏感,精度損失后容易失效
CrossEntropyLoss包含 log(softmax),fp16 精度不足導致數值不穩定
Attention(部分實現)scaled dot-product 會導致爆炸,尤其是大輸入或長序列時
Exp, Div, Log本身不穩定,數值小容易下溢出為 0

3.模型輸入也要相應的進行轉化,否則會遇到類型不匹配的問題

?解決方法1:使用autocast

from torch.cuda.amp import autocastwith autocast():output = model(x)  # 自動在每一層內部管理精度轉換

但autocast只針對模塊的外部類型來判斷是否進行類型轉化(如nn.Linear, nn.Conv2d),但是自定義的模塊(類)autocast不會進行類型轉換(autocast只是解決了類型不匹配的問題,但是低精度產生的梯度爆炸等問題無法解決,由反向傳播時gradscaler解決)

問題場景AMP 是否能處理說明
輸入是 fp16,模塊需要 fp32? autocast() 會自動轉換
自定義模塊內部 +,/ 導致類型錯? 你要自己管理,AMP 不管你自寫的算子
梯度為 0 或爆炸? GradScaler() 自動放大/還原
權重混用不同精度? 支持
推理時類型優化(加速,混用不同精度)? 只用 autocast() 即可

解決方法2:手動轉化類型

# 例如 LayerNorm 中人為轉 float32:
def forward(self, x):orig_type = x.dtyperet = super().forward(x.float())  # 保證 LayerNorm 在 float32 下執行return ret.to(orig_type)

4.關于forward的輸出

# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()

logit_scale是縮放因子,定義是self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

logits_per_image是圖像視角下的相似度分布,用于計算圖像到文本的對比損失

logits_per_text是文本視角下的相似度分布,和圖像視角下對稱。

5.關于權重初始化

(1)ResNet的bn3初始化為0

for resnet_block in [self.visual.layer1,self.visual.layer2,self.visual.layer3,self.visual.layer4]:for name, param in resnet_block.named_parameters():if name.endswith("bn3.weight"):nn.init.zeros_(param)

手動初始化bn3.weight為0確保為恒等映射,從而防止殘差支路輸出不穩定、擾動太大的問題。

(2)CLIP中的手動初始化和自動初始化

CLIP只手動初始化了一些對訓練穩定性或性能影響較大的模塊,如embedding和位置編碼(nanoGPT中也對這兩個部分進行了手動初始化)、QKVC投影、transformer最后輸出的初始化

    def initialize_parameters(self):nn.init.normal_(self.token_embedding.weight, std=0.02)nn.init.normal_(self.positional_embedding, std=0.01)if isinstance(self.visual, ModifiedResNet):if self.visual.attnpool is not None:std = self.visual.attnpool.c_proj.in_features ** -0.5nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:for name, param in resnet_block.named_parameters():if name.endswith("bn3.weight"):nn.init.zeros_(param)proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)attn_std = self.transformer.width ** -0.5fc_std = (2 * self.transformer.width) ** -0.5for block in self.transformer.resblocks:nn.init.normal_(block.attn.in_proj_weight, std=attn_std)nn.init.normal_(block.attn.out_proj.weight, std=proj_std)nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)if self.text_projection is not None:nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

***與nanoGPT的_init_weights對比

    # mainself.apply(self._init_weights)# apply special scaled init to the residual projections, per GPT-2 paperfor pn, p in self.named_parameters():if pn.endswith('c_proj.weight'):torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))#init_weightdef _init_weights(self, module):if isinstance(module, nn.Linear):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)if module.bias is not None:torch.nn.init.zeros_(module.bias)elif isinstance(module, nn.Embedding):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

*GPT
GPT 結構對初始化非常敏感,GPT 使用殘差連接 + LayerNorm,梯度傳播對初始權重分布非常依賴。所以在初始化的時候Linear和Embedding的weight的mean都初始化為0

*CLIP

CLIP更復雜,只初始化關鍵敏感部件,如embedding、positional encoding、attention等。

***目前總結到的經驗

*建議手動初始化:

模塊類型初始化建議原因
Embedding手動正態初始化(如 std=0.01~0.02)防止稀疏索引導致偏置
Q/K/V Linear手動初始化(如 std=1/√d_k防止 attention dot-product 初始值爆炸
Positional Embedding正態初始化因為是 learnable 參數,數值不宜過大
殘差 block 最后一層(如 BN3)初始化為 0初始退化為恒等映射,提高收斂性
任何“關鍵分支”的 projection 層建議初始化如 CLIP 的 text_projection, image_projection

?一般不主動初始化:

模塊類型

理由
Conv2d, Linear默認初始化已很好,除非有論文要求
LayerNorm, BatchNorm默認 weight=1, bias=0 是最優策略
非殘差中的普通線性層默認即可

(3)初始化時std的設置

①?attn_std = self.transformer.width ** -0.5

標準的transformer初始化方法

②fc_std = (2 * self.transformer.width) ** -0.5

用于初始化FFN中的前向Linear層,第一層輸出通道很大(通常是 4×),為了避免輸出激活過大,std 要適當減小

x → Linear(width, 4*width) → GELU → Linear(4*width, width)

③proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)

用于 Residual AttentionBlock 最后投影的 Linear 層

來源:來自論文 Understanding the Difficulty of Training Transformers,特別適用于 深層 Transformer(如 GPT-3, CLIP)

核心思想是:

如果模型深度是 L 層,那每個 residual branch 疊加的方差也會增加,應該將其 std 縮小為 1/sqrt(2L)以穩定整體輸出。

?6.關于build_model的參數的使用

(1)

vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]

這里使用visual.conv1.weight的第一個維度的大小作為width,conv2d的weight的形狀是(out_channle, in_channel, patch_size[0], patch_size[1])。

另外這里補充一下ViT patch和傳統CNN卷積核的區別:
傳統CNN是使用多個小卷積堆疊構建大感受野(kernel_size較小,stride小于kernel_size允許重疊),而ViT是使用一個大kernel,把整塊patch當作token(kernel_size較大,stride=kernel_size,即不重復采樣)

(2)

vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])

每個 Transformer block 里會有一個 nn.MultiheadAttention 模塊,對應權重名如:visual.transformer.resblocks.0.attn.in_proj_weight

(3)

grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size

這里image_resolution是因為

self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))

(4)幾個易混淆的概念

名字意義舉例值類似于
vision_width通道維度64、128、256 等CNN 中的輸出 channels
output_width特征圖尺寸7、14 等feature map 的寬度
patch_sizepatch 的邊長32ViT 中的切片大小‘

(5)ResNet中image_resolution = output_width * 32

*32是因為在ResNet中總共下采樣了5次

模塊操作類型輸出尺寸
conv1stride=2變成 H/2 × W/2
stem_poolAvgPool2d(2)變成 H/4 × W/4
layer1無下采樣尺寸不變
layer2stride=2變成 H/8 × W/8
layer3stride=2變成 H/16 × W/16
layer4stride=2變成 H/32 × W/32 ? 最終輸出
attnpool空間尺寸 = H/32 × W/32

?(6)刪除state_dict中的一些輔助信息字段

    for key in ["input_resolution", "context_length", "vocab_size"]:if key in state_dict:del state_dict[key]

這些不是模型參數的一部分,加載模型權重前必須刪掉,否則會引起state_dict鍵不匹配

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/89630.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/89630.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/89630.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

國科大深度學習作業1-手寫數字識別實驗

背景介紹:單位實習,趁機摸魚,由于電腦只安裝了VSCode,所以算是從環境搭建寫起。 目錄 一、環境搭建 1. 安裝Anaconda 2. 創建Python環境 3. 安裝PyTorch 4. 安裝其他必要庫 二、在 VSCode 中配置環境 1. 安裝Pytho…

基于Spring Boot的綠園社區團購系統的設計與實現

第1章 摘 要 本設計與實現的基于Spring Boot的綠園社區團購系統,旨在為社區居民提供一套高效、便捷的團購購物解決方案。隨著電子商務的發展和社區居民對便捷購物需求的增加,傳統的團購模式已無法滿足用戶的個性化需求。本系統通過整合現代化技術&…

【51單片機四位數碼管從0循環顯示到99,每0.5秒增加一個數字,打擊鍵計數】2022-6-11

緣由 #include "REG52.h" unsigned char code smgduan[]{0x3f,0x06,0x5b,0x4f,0x66,0x6d,0x7d,0x07,0x7f,0x6f,0x77,0x7c,0x39,0x5e,0x79,0x71,0,64,15,56}; //共陰0~F消隱減號 unsigned char Js0, miao0;//中斷計時 秒 分 時 毫秒 unsigned int shu0; //bit Mb0;//…

如何通過python腳本向redis和mongoDB傳點位數據

向MongoDB傳數據 from pymongo import MongoClient #導入庫對應的庫localhost "172.16.0.203" #數據庫IP地址 baseName "GreenNagoya" client MongoClient(localhost, 27017, username"admin", password"zdiai123") #數…

昆侖通泰觸摸屏Modbus TCP服務器工程 || TCP客戶端工程

目錄 一、Modbus TCP服務端 1.設備地址 2.實操及數據 二、Modbus TCP客戶端 1.結果及協議解析 一、Modbus TCP服務端 1.設備地址 --單元標識符 DI輸入/4個離散輸入 DO輸出/單個線圈輸出 輸入寄存器 讀輸入寄存器操作,寫輸入寄存器操作 保持寄存器 …

PyTorch 安裝使用教程

一、PyTorch 簡介 PyTorch 是由 Facebook AI Research 團隊開發的開源深度學習框架。它以動態圖機制、靈活性強、易于調試而著稱,廣泛應用于自然語言處理、計算機視覺和學術研究。 二、安裝 PyTorch 2.1 通過官網選擇安裝命令(推薦) 訪問官…

開源功能開關(feature flags) 和管理平臺之unleash

文章目錄 背景Flagsmith 和 Unleash什么是unleash架構Unleash Edge 安裝和使用Unleash SDKs開放API Tokens訪問**Server-side SDK (CLIENT)****查詢所有 Feature Toggles****查詢特定 Toggle** API token typesClient tokensFrontend tokensPersonal access tokensService acco…

細胞建模“圖靈測試”:解析學習虛擬細胞挑戰賽

一、AI能否預測細胞的未來? 想象一下,有一天我們不必一管管地做實驗,就能在計算機中模擬細胞對基因敲除、藥物處理乃至微環境變化的反應。這不再是科幻,而是“虛擬細胞”(Virtual Cell)研究的宏大目標。然…

centos9安裝docker Dify

CentOS | Docker Docs yum -y install gcc gcc-c yum-utils Docker 官方的 YUM 軟件倉庫配置文件到系統,設置存儲庫 yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo 也可以從阿里云下(我選擇上面的) yum-config-manager --add-re…

基于Jenkins和Kubernetes構建DevOps自動化運維管理平臺

目錄 引言 基礎概念 DevOps概述 Jenkins簡介 Kubernetes簡介 Jenkins與Kubernetes的關系 Jenkins與Kubernetes的集成 集成架構 安裝和配置 安裝Jenkins 安裝Kubernetes插件 配置Kubernetes連接 配置Jenkins Agent Jenkins Pipeline與Kubernetes集成 Pipeline定義…

MySQL 8.0 OCP 1Z0-908 題目解析(18)

題目69 Choose three. A MySQL server is monitored using MySQL Enterprise Monitor’s agentless installation. Which three features are available with this installation method? □ A) MySQL Replication monitoring □ B) security-related advisor warnings □ …

【mongodb】安裝和使用mongod

文章目錄 前言一、如何安裝?二、使用步驟1. 開啟mongod服務2. 客戶端連接數據庫3. 數據庫指令 總結 前言 Mongodb的安裝可以直接安裝系統默認的版本,也可以安裝官網維護的版本,相對而言更推薦安裝官網維護的版本,版本也相當更新。…

云效DevOps vs Gitee vs 自建GitLab的技術選型

針對「云效DevOps vs Gitee vs 自建GitLab」的技術選型,我們從核心需求、成本、運維、擴展性四個維度進行深度對比,并給出場景化決策建議: 一、核心能力對比表 能力維度云效DevOpsGitee自建GitLab(社區版/企業版)代碼…

CentOS 7 安裝RabbitMQ詳細教程

前言:在分布式系統架構中,消息隊列作為數據流轉的 “高速公路”,是微服務架構不可或缺的核心組件。RabbitMQ 憑借其穩定的性能、靈活的路由機制和強大的生態支持,成為企業級消息中間件的首選之一。不過,當我們聚焦 Cen…

Python爬蟲用途和介紹

目錄 什么是Python爬蟲 Python爬蟲用途 Python爬蟲可以獲得那些數據 Python爬蟲的用途 反爬是什么 常見的反爬措施 Python爬蟲技術模塊總結 獲取網站的原始響應數據 獲取到響應數據對響應數據進行過濾 對收集好的數據進行存儲 抵御反爬機制 Python爬蟲框架 Python…

uni-app開發app保持登錄狀態

在 uni-app 中實現用戶登錄一次后在 token 過期前一直免登錄的功能,可以通過以下幾個關鍵步驟實現:本地持久化存儲 Token、使用請求與響應攔截器自動處理 Token 刷新、以及在 App.vue 中結合 pages.json 設置登錄狀態跳轉邏輯。 ? 一、pages.json 配置說…

21、MQ常見問題梳理

目錄 ? 、MQ如何保證消息不丟失 1 、哪些環節可能會丟消息 2 、?產者發送消息如何保證不丟失 2.1、?產者發送消息確認機制 2.2、Rocket MQ的事務消息機制 2.3 、Broker寫?數據如何保證不丟失 2.3.1** ?先需要理解操作系統是如何把消息寫?到磁盤的**。 2.3.2然后來…

MySQL數據庫--SQL DDL語句

SQL--DDL語句 1,DDL-數據庫操作2,DDL-表操作-查詢3,DDL-表操作-創建4,DDL-表操作-數據類型4.1,DDL-表操作-數值類型4.2,DDL-表操作-字符串類型4.3,DDL-表操作-日期時間類型4.4,實例 …

Spring Cloud 服務追蹤實戰:使用 Zipkin 構建分布式鏈路追蹤

Spring Cloud 服務追蹤實戰:使用 Zipkin 構建分布式鏈路追蹤 在分布式微服務架構中,一個用戶請求往往需要經過多個服務協作完成,如果出現性能瓶頸或異常,排查會非常困難。此時,分布式鏈路追蹤(Distributed…

Linux云計算基礎篇(6)

一、IO重定向和管道 stdin:standard input 標準輸入 stdout:standard output 標準輸出 stderr: standard error 標準錯誤輸出 舉例 find /etc/ -name passwd > find.out 將正確的輸出重定向在這個find.ou…