之前已經完整的拆解了CLIP中所用到的ResNet、ViT和Transformer三個模型(CLIP拆解-CSDN博客),這篇將講解model.py實現中的其他細節。
1.關于ResNet模型中vision_head的設置
ResNet:
vision_heads = vision_width * 32 // 64
ViT:
vision_heads = vision_width // 64
ResNet需要乘32是因為經過前面卷積處理后輸入AttentionPool2d的是width*32,所以計算head的時候要把這個考慮進去。至于這里的64是分為多頭后每一個頭的embed的通道數,ResNet通常取64,ViT-B常取768
2.關于conver_weights
convert_weights()
是為了節省顯存、提高推理速度,將模型中適合的權重轉換為 fp16。
(1)half()的作用 就是把fp32轉為fp16,如果輸入本身是 fp16,那將不進行任何處理。
(2)一些結構不建議轉化為fp16,因為轉化后會不穩定,所以選擇性的處理
def _convert_weights_to_fp16(l):if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):l.weight.data = l.weight.data.half()if l.bias is not None:l.bias.data = l.bias.data.half()if isinstance(l, nn.MultiheadAttention):for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:tensor = getattr(l, attr)if tensor is not None:tensor.data = tensor.data.half()for name in ["text_projection", "proj"]:if hasattr(l, name):attr = getattr(l, name)if attr is not None:attr.data = attr.data.half()
下面是常見的不建議使用fp16的模塊:
模塊/操作 | 原因說明 |
---|---|
LayerNorm / BatchNorm | 均值/方差運算容易數值下溢,精度敏感 |
Softmax / LogSoftmax | 輸出接近 0 或 1,fp16 下舍入誤差大 |
Sigmoid / Tanh | 對小輸入不敏感,精度損失后容易失效 |
CrossEntropyLoss | 包含 log(softmax) ,fp16 精度不足導致數值不穩定 |
Attention (部分實現) | scaled dot-product 會導致爆炸,尤其是大輸入或長序列時 |
Exp , Div , Log | 本身不穩定,數值小容易下溢出為 0 |
3.模型輸入也要相應的進行轉化,否則會遇到類型不匹配的問題
?解決方法1:使用autocast
from torch.cuda.amp import autocastwith autocast():output = model(x) # 自動在每一層內部管理精度轉換
但autocast只針對模塊的外部類型來判斷是否進行類型轉化(如nn.Linear, nn.Conv2d),但是自定義的模塊(類)autocast不會進行類型轉換(autocast只是解決了類型不匹配的問題,但是低精度產生的梯度爆炸等問題無法解決,由反向傳播時gradscaler解決)
問題場景 | AMP 是否能處理 | 說明 |
---|---|---|
輸入是 fp16,模塊需要 fp32 | ? autocast() 會自動轉換 | |
自定義模塊內部 + ,/ 導致類型錯 | ? 你要自己管理,AMP 不管你自寫的算子 | |
梯度為 0 或爆炸 | ? GradScaler() 自動放大/還原 | |
權重混用不同精度 | ? 支持 | |
推理時類型優化(加速,混用不同精度) | ? 只用 autocast() 即可 |
解決方法2:手動轉化類型
# 例如 LayerNorm 中人為轉 float32:
def forward(self, x):orig_type = x.dtyperet = super().forward(x.float()) # 保證 LayerNorm 在 float32 下執行return ret.to(orig_type)
4.關于forward的輸出
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
logit_scale是縮放因子,定義是self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
logits_per_image是圖像視角下的相似度分布,用于計算圖像到文本的對比損失
logits_per_text是文本視角下的相似度分布,和圖像視角下對稱。
5.關于權重初始化
(1)ResNet的bn3初始化為0
for resnet_block in [self.visual.layer1,self.visual.layer2,self.visual.layer3,self.visual.layer4]:for name, param in resnet_block.named_parameters():if name.endswith("bn3.weight"):nn.init.zeros_(param)
手動初始化bn3.weight為0確保為恒等映射,從而防止殘差支路輸出不穩定、擾動太大的問題。
(2)CLIP中的手動初始化和自動初始化
CLIP只手動初始化了一些對訓練穩定性或性能影響較大的模塊,如embedding和位置編碼(nanoGPT中也對這兩個部分進行了手動初始化)、QKVC投影、transformer最后輸出的初始化
def initialize_parameters(self):nn.init.normal_(self.token_embedding.weight, std=0.02)nn.init.normal_(self.positional_embedding, std=0.01)if isinstance(self.visual, ModifiedResNet):if self.visual.attnpool is not None:std = self.visual.attnpool.c_proj.in_features ** -0.5nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:for name, param in resnet_block.named_parameters():if name.endswith("bn3.weight"):nn.init.zeros_(param)proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)attn_std = self.transformer.width ** -0.5fc_std = (2 * self.transformer.width) ** -0.5for block in self.transformer.resblocks:nn.init.normal_(block.attn.in_proj_weight, std=attn_std)nn.init.normal_(block.attn.out_proj.weight, std=proj_std)nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)if self.text_projection is not None:nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
***與nanoGPT的_init_weights對比
# mainself.apply(self._init_weights)# apply special scaled init to the residual projections, per GPT-2 paperfor pn, p in self.named_parameters():if pn.endswith('c_proj.weight'):torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))#init_weightdef _init_weights(self, module):if isinstance(module, nn.Linear):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)if module.bias is not None:torch.nn.init.zeros_(module.bias)elif isinstance(module, nn.Embedding):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
*GPT
GPT 結構對初始化非常敏感,GPT 使用殘差連接 + LayerNorm,梯度傳播對初始權重分布非常依賴。所以在初始化的時候Linear和Embedding的weight的mean都初始化為0
*CLIP
CLIP更復雜,只初始化關鍵敏感部件,如embedding、positional encoding、attention等。
***目前總結到的經驗
*建議手動初始化:
模塊類型 | 初始化建議 | 原因 |
---|---|---|
Embedding | 手動正態初始化(如 std=0.01~0.02) | 防止稀疏索引導致偏置 |
Q/K/V Linear | 手動初始化(如 std=1/√d_k ) | 防止 attention dot-product 初始值爆炸 |
Positional Embedding | 正態初始化 | 因為是 learnable 參數,數值不宜過大 |
殘差 block 最后一層(如 BN3) | 初始化為 0 | 初始退化為恒等映射,提高收斂性 |
任何“關鍵分支”的 projection 層 | 建議初始化 | 如 CLIP 的 text_projection , image_projection |
?一般不主動初始化:
模塊類型 | 理由 |
---|---|
Conv2d , Linear | 默認初始化已很好,除非有論文要求 |
LayerNorm , BatchNorm | 默認 weight=1 , bias=0 是最優策略 |
非殘差中的普通線性層 | 默認即可 |
(3)初始化時std的設置
①?attn_std = self.transformer.width ** -0.5
標準的transformer初始化方法
②fc_std = (2 * self.transformer.width) ** -0.5
用于初始化FFN中的前向Linear層,第一層輸出通道很大(通常是 4×),為了避免輸出激活過大,std 要適當減小。
x → Linear(width, 4*width) → GELU → Linear(4*width, width)
③proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
用于 Residual AttentionBlock 最后投影的 Linear 層
來源:來自論文 Understanding the Difficulty of Training Transformers,特別適用于 深層 Transformer(如 GPT-3, CLIP)。
核心思想是:
如果模型深度是 L 層,那每個 residual branch 疊加的方差也會增加,應該將其 std 縮小為 1/sqrt(2L)以穩定整體輸出。
?6.關于build_model的參數的使用
(1)
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
這里使用visual.conv1.weight的第一個維度的大小作為width,conv2d的weight的形狀是(out_channle, in_channel, patch_size[0], patch_size[1])。
另外這里補充一下ViT patch和傳統CNN卷積核的區別:
傳統CNN是使用多個小卷積堆疊構建大感受野(kernel_size較小,stride小于kernel_size允許重疊),而ViT是使用一個大kernel,把整塊patch當作token(kernel_size較大,stride=kernel_size,即不重復采樣)
(2)
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
每個 Transformer block 里會有一個 nn.MultiheadAttention
模塊,對應權重名如:visual.transformer.resblocks.0.attn.in_proj_weight
(3)
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
這里image_resolution是因為
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
(4)幾個易混淆的概念
名字 | 意義 | 舉例值 | 類似于 |
---|---|---|---|
vision_width | 通道維度 | 64、128、256 等 | CNN 中的輸出 channels |
output_width | 特征圖尺寸 | 7、14 等 | feature map 的寬度 |
patch_size | patch 的邊長 | 32 | ViT 中的切片大小‘ |
(5)ResNet中image_resolution = output_width * 32
*32是因為在ResNet中總共下采樣了5次
模塊 | 操作類型 | 輸出尺寸 |
---|---|---|
conv1 | stride=2 | 變成 H/2 × W/2 |
stem_pool | AvgPool2d(2) | 變成 H/4 × W/4 |
layer1 | 無下采樣 | 尺寸不變 |
layer2 | stride=2 | 變成 H/8 × W/8 |
layer3 | stride=2 | 變成 H/16 × W/16 |
layer4 | stride=2 | 變成 H/32 × W/32 ? 最終輸出 |
attnpool | 空間尺寸 = H/32 × W/32 |
?(6)刪除state_dict中的一些輔助信息字段
for key in ["input_resolution", "context_length", "vocab_size"]:if key in state_dict:del state_dict[key]
這些不是模型參數的一部分,加載模型權重前必須刪掉,否則會引起state_dict鍵不匹配