DAY 50 預訓練模型+CBAM模塊

@浙大疏錦行https://blog.csdn.net/weixin_45655710

知識點回顧:

  1. resnet結構解析
  2. CBAM放置位置的思考
  3. 針對預訓練模型的訓練策略
    1. 差異化學習率
    2. 三階段微調

作業:

  1. 好好理解下resnet18的模型結構
  2. 嘗試對vgg16+cbam進行微調策略
ResNet-18 結構核心思想

可以將ResNet-18想象成一個高效的“圖像信息處理流水線”,它分為三個核心部分

  1. “開胃菜” - 輸入預處理 (Stem)

    • 組成:一個大的7x7卷積層 (conv1) + 一個最大池化層 (maxpool)。

    • 作用:對輸入的原始大尺寸圖像(如224x224)進行一次快速、大刀闊斧的特征提取和尺寸壓縮。它迅速將圖像尺寸減小到56x56,為后續更精細的處理做好準備,像是一道開胃菜,快速打開味蕾。

  2. “主菜” - 四組殘差塊 (Layer1, 2, 3, 4)

    • 組成:這是ResNet的心臟,由四組Sequential模塊構成,每組里面包含2個BasicBlock(殘差塊)。

    • 作用:這是真正進行深度特征提取的地方。其最精妙的設計在于:

      • 層級遞進:從layer1layer4,特征圖的空間尺寸逐級減半(56→28→14→7),而通道數逐級翻倍(64→128→256→512)。這實現了“犧牲空間細節,換取更高層語義信息”的經典策略。

      • 殘差連接:每個BasicBlock內部的“跳躍連接”(out += identity)是其靈魂。它允許信息和梯度“抄近道”,直接從塊的輸入流向輸出,完美解決了深度網絡中因信息丟失導致的“網絡退化”和梯度消失問題。

  3. “甜點” - 分類頭 (Head)

    • 組成:一個全局平均池化層 (avgpool) + 一個全連接層 (fc)。

    • 作用

      • avgpool:將layer4輸出的512x7x7的復雜特征圖,暴力壓縮成一個512維的特征向量,濃縮了整張圖最高級的語義信息。

      • fc:扮演最終“裁判”的角色,將這個512維的特征向量映射到最終的類別得分上(例如,ImageNet的1000類)。

總結來說,ResNet-18的優雅之處在于其清晰的模塊化設計和革命性的殘差連接,它通過“尺寸減半,通道加倍”的策略逐層加深語義理解,并利用“跳躍連接”保證了信息流的暢通,從而能夠構建出既深又易于訓練的強大網絡。

對VGG16 + CBAM 進行微調

VGG16以其結構統一、簡單(全是3x3卷積和2x2池化)而著稱,但缺點是參數量巨大。我們將為其集成CBAM,并應用類似的分階段微調策略。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import time
from tqdm import tqdm# --- 模塊定義 (CBAM 和數據加載器,與之前一致) ---
class ChannelAttention(nn.Module):def __init__(self, in_channels, ratio=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // ratio, bias=False), nn.ReLU(),nn.Linear(in_channels // ratio, in_channels, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):b, c, _, _ = x.shapeavg_out = self.fc(self.avg_pool(x).view(b, c))max_out = self.fc(self.max_pool(x).view(b, c))attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)return x * attentionclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)pool_out = torch.cat([avg_out, max_out], dim=1)attention = self.conv(pool_out)return x * self.sigmoid(attention)class CBAM(nn.Module):def __init__(self, in_channels, ratio=16, kernel_size=7):super().__init__()self.channel_attn = ChannelAttention(in_channels, ratio)self.spatial_attn = SpatialAttention(kernel_size)def forward(self, x):return self.spatial_attn(self.channel_attn(x))def get_cifar10_loaders(batch_size=64, resize_to=224): # VGG需要224x224輸入print(f"--- 正在準備數據 (圖像將縮放至 {resize_to}x{resize_to}) ---")transform = transforms.Compose([transforms.Resize(resize_to),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)print("? 數據加載器準備完成。")return train_loader, test_loader# --- 新增:VGG16 + CBAM 模型定義 ---
class VGG16_CBAM(nn.Module):def __init__(self, num_classes=10, pretrained=True):super().__init__()# 加載預訓練的VGG16的特征提取部分vgg_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None).features# 我們將VGG的特征提取層按池化層分割,并在每個塊后插入CBAMself.features = nn.ModuleList()self.cbam_modules = nn.ModuleList()current_channels = 3vgg_block = []for layer in vgg_features:vgg_block.append(layer)if isinstance(layer, nn.Conv2d):current_channels = layer.out_channelsif isinstance(layer, nn.MaxPool2d):self.features.append(nn.Sequential(*vgg_block))self.cbam_modules.append(CBAM(current_channels))vgg_block = [] # 開始新的塊# VGG的分類器部分self.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(),nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(),nn.Linear(4096, num_classes),)def forward(self, x):for feature_block, cbam_module in zip(self.features, self.cbam_modules):x = feature_block(x)x = cbam_module(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x# --- 訓練和評估框架 (復用) ---
def run_experiment(model_name, model, device, train_loader, test_loader, epochs):print(f"\n{'='*25} 開始實驗: {model_name} {'='*25}")model.to(device)total_params = sum(p.numel() for p in model.parameters())print(f"模型總參數量: {total_params / 1e6:.2f}M")criterion = nn.CrossEntropyLoss()# 差異化學習率:為不同的部分設置不同的學習率optimizer = optim.Adam([{'params': model.features.parameters(), 'lr': 1e-5}, # 特征提取層使用極低學習率{'params': model.cbam_modules.parameters(), 'lr': 1e-4}, # CBAM模塊使用中等學習率{'params': model.classifier.parameters(), 'lr': 1e-3} # 分類頭使用較高學習率])for epoch in range(1, epochs + 1):model.train()loop = tqdm(train_loader, desc=f"Epoch [{epoch}/{epochs}] Training", leave=False)for data, target in loop:data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()loop.set_postfix(loss=loss.item())loop.close()model.eval()test_loss, correct = 0, 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item() * data.size(0)pred = output.argmax(dim=1)correct += pred.eq(target).sum().item()avg_test_loss = test_loss / len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)print(f"Epoch {epoch} 完成 | 測試集損失: {avg_test_loss:.4f} | 測試集準確率: {accuracy:.2f}%")# --- 主執行流程 ---
if __name__ == "__main__":DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")EPOCHS = 10 # 僅作演示,VGG需要更多輪次BATCH_SIZE = 32 # VGG參數量大,減小batch size防止顯存溢出train_loader, test_loader = get_cifar10_loaders(batch_size=BATCH_SIZE)vgg_cbam_model = VGG16_CBAM()run_experiment("VGG16+CBAM", vgg_cbam_model, DEVICE, train_loader, test_loader, EPOCHS)
VGG16+CBAM 微調策略解析
  1. 模型修改 (VGG16_CBAM)

    • 拆分與重組:VGG16的預訓練模型中,特征提取部分model.features是一個包含所有卷積和池化層的nn.Sequential。我們不能直接在中間插入CBAM。因此,我們遍歷了vgg_features中的所有層,以MaxPool2d為界,將它們拆分成了5個卷積塊。

    • 插入CBAM:在每個卷積塊之后,我們都插入了一個對應通道數的CBAM模塊。

    • 保留分類頭:原始的model.classifier(全連接層)被保留,只修改最后一層以適應CIFAR-10的10個類別。

  2. 數據預處理適配

    • VGG16在ImageNet上預訓練時,接收的是224x224的圖像。為了最大化利用預訓練權重,我們在get_cifar10_loaders函數中,通過transforms.Resize(224)將CIFAR-10的32x32圖像放大224x224

  3. 訓練策略:差異化學習率

    • 由于VGG16的參數量巨大(超過1.3億),如果全局使用相同的學習率進行微調,很容易破壞已經學得很好的預訓練權重。

    • 我們采用了一種更精細的差異化學習率 (Differential Learning Rates) 策略:

      • 特征提取層 (model.features):這些是“資深專家”,權重已經很好了,我們給一個極低的學習率1e-5),讓它們只做微小的調整。

      • CBAM模塊 (model.cbam_modules):這些是新加入的“顧問”,需要學習,但不能太激進,給一個中等學習率1e-4)。

      • 分類頭 (model.classifier):這是完全為新任務定制的“新員工”,需要從頭快速學習,給一個較高的學習率1e-3)。

    • 這種策略通過optim.Adam接收一個參數組列表來實現,是微調大型模型時非常有效且常用的高級技巧。

  4. Batch Size調整
    批次大小調整

    • VGG16的參數量和中間激活值都非常大,對顯存的消耗遠超ResNet18。因此,我們將BATCH_SIZE減小到32,以防止顯存溢出(OOM)錯誤。

通過這個實驗,不僅能實踐如何將注意力模塊集成到一個全新的經典網絡(VGG16)中,還能學習到微調大型模型時更高級、更精細的訓練策略,如差異化學習率。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/913252.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/913252.shtml
英文地址,請注明出處:http://en.pswp.cn/news/913252.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

docker連接mysql

查看在運行的容器:docker ps -s 進入容器:docker exec -it 容器號或名 /bin/bash,如:docker exec -it c04c438ff177 /bin/bash 或docker exec -it mysql /bin/bash。 3. 登錄mysql:mysql -uroot -p123456

javaweb第182節Linux概述~ 虛擬機連接不上FinalShell

問題描述 虛擬機無法連接到finalshell 報錯 session.connect:java.net.socketexception:connection reset 或者 connection is closed by foreign host 解決 我經過一系列的排查,花費了一天的時間后,發現,只是因為,我將連接…

高壓電纜護層安全的智能防線:TLKS-PLGD 監控設備深度解析

在現代電力系統龐大復雜的網絡中,高壓電纜護層是守護電力傳輸的 "隱形鎧甲",其安全直接影響電網穩定。傳統監測手段響應慢、精度低,難以滿足安全運維需求。TLKS-PLGD 高壓電纜護層環流監控設備應運而生,提供智能化解決方…

Element-Plus Cascader 級聯選擇器獲取節點名稱和value值方法

html 部分 <template><el-cascaderref"selectAeraRef":options"areas":disabled"disabled":props"optionProps"v-model"selectedOptions"filterablechange"handleChange"><template #default"…

STM32中實現shell控制臺(命令解析實現)

文章目錄一、核心設計思想二、命令系統實現詳解&#xff08;含完整注釋&#xff09;1. 示例命令函數實現2. 初始化命令系統3. 命令注冊函數4. 命令查找函數5. 命令執行函數三、命令結構體&#xff08;cmd\_t&#xff09;四、運行效果示例五、小結在嵌入式系統的命令行控制臺&am…

基于matlab的二連桿機械臂PD控制的仿真

基于matlab的二連桿機械臂PD控制的仿真。。。 chap3_5input.m , 1206 d2plant1.m , 1364 hs_err_pid2808.log , 15398 hs_err_pid4008.log , 15494 lx_plot.m , 885 PD_Control.mdl , 35066 tiaojie.m , 737 chap2_1ctrl.asv , 988 chap2_1ctrl.m , 905

TCP、HTTP/1.1 和HTTP/2 協議

TCP、HTTP/1.1 和 HTTP/2 是互聯網通信中的核心協議&#xff0c;它們在網絡分層中處于不同層級&#xff0c;各有特點且逐步演進。以下是它們的詳細對比和關鍵特性&#xff1a;1. TCP&#xff08;傳輸控制協議&#xff09; 層級&#xff1a;傳輸層&#xff08;OSI第4層&#xff…

Java+Vue開發的進銷存ERP系統,集采購、銷售、庫存管理,助力企業數字化運營

前言&#xff1a;在當今競爭激烈的商業環境中&#xff0c;企業對于高效管理商品流通、采購、銷售、庫存以及財務結算等核心業務流程的需求日益迫切。進銷存ERP系統作為一種集成化的企業管理解決方案&#xff0c;能夠整合企業資源&#xff0c;實現信息的實時共享與協同運作&…

【趣談】Android多用戶導致的UserID、UID、shareUserId、UserHandle術語混亂討論

【趣談】Android多用戶導致的UserID、UID、shareUserId、UserHandle術語混亂討論 備注一、概述二、概念對比1.UID2.shareUserId3.UserHandle4.UserID 三、結論 備注 2025/07/02 星期三 在與Android打交道時總遇到UserID、UID、shareUserId、UserHandle這些術語&#xff0c;但是…

P1424 小魚的航程(改進版)

題目描述有一只小魚&#xff0c;它平日每天游泳 250 公里&#xff0c;周末休息&#xff08;實行雙休日)&#xff0c;假設從周 x 開始算起&#xff0c;過了 n 天以后&#xff0c;小魚一共累計游泳了多少公里呢&#xff1f;輸入格式輸入兩個正整數 x,n&#xff0c;表示從周 x 算起…

<二>Sping-AI alibaba 入門-記憶聊天及持久化

請看文檔&#xff0c;流程不再贅述&#xff1a;官網及其示例 簡易聊天 環境變量 引入Spring AI Alibaba 記憶對話還需要我們有數據庫進行存儲&#xff0c;mysql&#xff1a;mysql-connector-java <?xml version"1.0" encoding"UTF-8"?> <pr…

【機器學習深度學習】模型參數量、微調效率和硬件資源的平衡點

目錄 一、核心矛盾是什么&#xff1f; 二、微調本質&#xff1a;不是全調&#xff0c;是“挑著調” 三、如何平衡&#xff1f; 3.1 核心策略 3.2 參數量 vs 微調難度 四、主流輕量微調方案盤點 4.1 凍結部分參數 4.2 LoRA&#xff08;低秩微調&#xff09; 4.3 量化訓…

【V13.0 - 戰略篇】從“完播率”到“價值網絡”:訓練能預測商業潛力的AI矩陣

在上一篇 《超越“平均分”&#xff1a;用多目標預測捕捉觀眾的“心跳曲線”》 中&#xff0c;我們成功地讓AI學會了預測觀眾留存曲線&#xff0c;它的診斷能力已經深入到了視頻的“過程”層面&#xff0c;能精確地指出觀眾是在哪個瞬間失去耐心。 我的AI現在像一個頂級的‘心…

java微服務(Springboot篇)——————IDEA搭建第一個Springboot入門項目

在正文開始之前我們先來解決一些概念性的問題 &#x1f355;&#x1f355;&#x1f355; 問題1&#xff1a;Spring&#xff0c;Spring MVC&#xff0c;Spring Boot和Spring Cloud之間的區別與聯系&#xff1f; &#x1f36c;&#x1f36c;&#x1f36c;&#xff08;1&#xff0…

服務器間接口安全問題的全面分析

一、服務器接口安全核心威脅 文章目錄**一、服務器接口安全核心威脅**![在這里插入圖片描述](https://i-blog.csdnimg.cn/direct/6f54698b9a22439892f0c213bc0fd1f4.png)**二、六大安全方案深度對比****1. IP白名單機制****2. 雙向TLS認證(mTLS)****3. JWT簽名認證****4. OAuth…

vs code關閉函數形參提示

問題&#xff1a;函數內出現灰色的形參提示 需求/矛盾&#xff1a; 這個提示對老牛來說可能是一種干擾&#xff0c;比如不好對齊控制一行代碼的長度&#xff0c;或者容易看走眼&#xff0c;造成眼花繚亂的體驗。 關閉方法&#xff1a; 進入設置&#xff0c;輸入inlay Hints&…

ESXi 8.0安裝

使用群暉&#xff0c;突然nvme固態壞了 新nvme固態&#xff0c;先在PC上格式化下&#xff0c;不然可能N100可能不認 啟動&#xff0c;等待很長時間 回車 F11 輸入密碼&#xff0c;字母小寫字母大寫數字 拔掉U盤&#xff0c;回車重啟 網絡配置 按F2&#xff0c; 輸入密碼&…

【git學習】第2課:查看歷史與版本回退

好的&#xff0c;我們進入 第2課&#xff1a;版本查看與回退機制&#xff0c;本課你將學會如何查看提交歷史、對比更改&#xff0c;并掌握多種回退版本的方法。&#x1f4d8; 第2課&#xff1a;查看歷史與版本回退&#x1f3af; 本課目標熟練查看 Git 提交記錄掌握差異查看、版…

攝像頭AI智能識別工程車技術及應用前景展望

攝像頭AI自動識別工程車是智能交通系統和工程安全管理領域的一項重要技術。它通過圖像識別技術和深度學習算法&#xff0c;實現對工程車的自動檢測和識別&#xff0c;從而提高了施工現場的安全性和管理效率。以下是對該技術及其應用的詳細介紹&#xff1a;一、技術實現數據收集…

Windows服務器安全配置:組策略與權限管理最佳實踐

Windows服務器是企業常用的服務器操作系統&#xff0c;但其開放性和復雜性也使其成為攻擊者的目標。通過正確配置組策略和權限管理&#xff0c;可以有效提高安全性&#xff0c;防止未經授權的訪問和惡意軟件的入侵。以下是詳細的安全配置指南和最佳實踐。 1. 為什么組策略和權限…