Python day50

@浙大疏錦行?python day50.

  • 在預訓練模型(resnet18)中添加cbam注意力機制,需要修改模型的架構,同時應該考慮插入的cbam注意力機制模塊的位置;
import torch
import torch.nn as nn
from torchvision import models# 自定義ResNet18模型,插入CBAM模塊
class ResNet18_CBAM(nn.Module):def __init__(self, num_classes=10, pretrained=True, cbam_ratio=16, cbam_kernel=7):super().__init__()# 加載預訓練ResNet18self.backbone = models.resnet18(pretrained=pretrained) # 修改首層卷積以適應32x32輸入(CIFAR10)self.backbone.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)self.backbone.maxpool = nn.Identity()  # 移除原始MaxPool層(因輸入尺寸小)# 在每個殘差塊組后添加CBAM模塊self.cbam_layer1 = CBAM(in_channels=64, ratio=cbam_ratio, kernel_size=cbam_kernel)self.cbam_layer2 = CBAM(in_channels=128, ratio=cbam_ratio, kernel_size=cbam_kernel)self.cbam_layer3 = CBAM(in_channels=256, ratio=cbam_ratio, kernel_size=cbam_kernel)self.cbam_layer4 = CBAM(in_channels=512, ratio=cbam_ratio, kernel_size=cbam_kernel)# 修改分類頭self.backbone.fc = nn.Linear(in_features=512, out_features=num_classes)def forward(self, x):# 主干特征提取x = self.backbone.conv1(x)x = self.backbone.bn1(x)x = self.backbone.relu(x)  # [B, 64, 32, 32]# 第一層殘差塊 + CBAMx = self.backbone.layer1(x)  # [B, 64, 32, 32]x = self.cbam_layer1(x)# 第二層殘差塊 + CBAMx = self.backbone.layer2(x)  # [B, 128, 16, 16]x = self.cbam_layer2(x)# 第三層殘差塊 + CBAMx = self.backbone.layer3(x)  # [B, 256, 8, 8]x = self.cbam_layer3(x)# 第四層殘差塊 + CBAMx = self.backbone.layer4(x)  # [B, 512, 4, 4]x = self.cbam_layer4(x)# 全局平均池化 + 分類x = self.backbone.avgpool(x)  # [B, 512, 1, 1]x = torch.flatten(x, 1)  # [B, 512]x = self.backbone.fc(x)  # [B, 10]return x# 初始化模型并移至設備
model = ResNet18_CBAM().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
  • 修改模型結構后,需要考慮模型訓練的策略,一般來說可以先凍結原有的部分進行訓練以期待新增的部分可以獲得一個不錯的表現;之后解凍原有部分中的高層layer并賦予一個較低的學習率來保證不會出現不應該的錯誤;最后解凍所有參數,也是賦予較低的學習率,來學習最終的端到端任務。
def set_trainable_layers(model, trainable_parts):print(f"\n---> 解凍以下部分并設為可訓練: {trainable_parts}")for name, param in model.named_parameters():param.requires_grad = Falsefor part in trainable_parts:if part in name:param.requires_grad = Truebreakdef train_staged_finetuning(model, criterion, train_loader, test_loader, device, epochs):optimizer = None# 初始化歷史記錄列表,與你的要求一致all_iter_losses, iter_indices = [], []train_acc_history, test_acc_history = [], []train_loss_history, test_loss_history = [], []for epoch in range(1, epochs + 1):epoch_start_time = time.time()# --- 動態調整學習率和凍結層 ---if epoch == 1:print("\n" + "="*50 + "\n🚀 **階段 1:訓練注意力模塊和分類頭**\n" + "="*50)set_trainable_layers(model, ["cbam", "backbone.fc"])optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)elif epoch == 6:print("\n" + "="*50 + "\n?? **階段 2:解凍高層卷積層 (layer3, layer4)**\n" + "="*50)set_trainable_layers(model, ["cbam", "backbone.fc", "backbone.layer3", "backbone.layer4"])optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)elif epoch == 21:print("\n" + "="*50 + "\n🛰? **階段 3:解凍所有層,進行全局微調**\n" + "="*50)for param in model.parameters(): param.requires_grad = Trueoptimizer = optim.Adam(model.parameters(), lr=1e-5)# --- 訓練循環 ---model.train()running_loss, correct, total = 0.0, 0, 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 記錄每個iteration的損失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append((epoch - 1) * len(train_loader) + batch_idx + 1)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 按你的要求,每100個batch打印一次if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 單Batch損失: {iter_loss:.4f} | 累計平均損失: {running_loss/(batch_idx+1):.4f}')epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totaltrain_loss_history.append(epoch_train_loss)train_acc_history.append(epoch_train_acc)# --- 測試循環 ---model.eval()test_loss, correct_test, total_test = 0, 0, 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_testtest_loss_history.append(epoch_test_loss)test_acc_history.append(epoch_test_acc)# 打印每個epoch的最終結果print(f'Epoch {epoch}/{epochs} 完成 | 耗時: {time.time() - epoch_start_time:.2f}s | 訓練準確率: {epoch_train_acc:.2f}% | 測試準確率: {epoch_test_acc:.2f}%')# 訓練結束后調用繪圖函數print("\n訓練完成! 開始繪制結果圖表...")plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)# 返回最終的測試準確率return epoch_test_accmodel = ResNet18_CBAM().to(device)
criterion = nn.CrossEntropyLoss()
epochs = 50print("開始使用帶分階段微調策略的ResNet18+CBAM模型進行訓練...")
final_accuracy = train_staged_finetuning(model, criterion, train_loader, test_loader, device, epochs)
print(f"訓練完成!最終測試準確率: {final_accuracy:.2f}%")torch.save(model.state_dict(), 'resnet18_cbam_finetuned.pth')
print("模型已保存為: resnet18_cbam_finetuned.pth")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/96216.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/96216.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/96216.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

VPS海外節點性能監控全攻略:從基礎配置到高級優化

在全球化業務部署中,VPS海外節點的穩定運行直接影響用戶體驗。本文將深入解析如何構建高效的性能監控體系,涵蓋網絡延遲檢測、資源閾值設置、告警機制優化等核心環節,幫助運維人員實現跨國服務器的可視化管控。 VPS海外節點性能監控全攻略&am…

C語言初學者筆記【結構體】

文章目錄一、結構體的使用1. 結構體聲明2. 變量創建與初始化3. 特殊聲明與陷阱二、內存對齊1. 規則:2. 示例分析:3. 修改默認對齊數:三、結構體傳參四、結構體實現位段1. 定義2. 內存分配3. 應用場景4. 跨平臺問題:5. 注意事項&am…

基于XGBoost算法的數據回歸預測 極限梯度提升算法 XGBoost

一、作品詳細簡介 1.1附件文件夾程序代碼截圖 全部完整源代碼,請在個人首頁置頂文章查看: 學行庫小秘_CSDN博客?編輯https://blog.csdn.net/weixin_47760707?spm1000.2115.3001.5343 1.2各文件夾說明 1.2.1 main.m主函數文件 該MATLAB 代碼實現了…

數據安全系列4:常用的對稱算法淺析

常用的算法介紹 常用的算法JAVA實現 jce及其它開源包介紹、對比 傳送門 數據安全系列1:開篇 數據安全系列2:單向散列函數概念 數據安全系列3:密碼技術概述 時代有浪潮,就有退去的時候 在我的博客文章里面,其中…

云計算學習100天-第26天

地址重寫地址重寫語法——關于Nginx服務器的地址重寫,主要用到的配置參數是rewrite 語法格式: rewrite regex replacement flag rewrite 舊地址 新地址 [選項]地址重寫步驟:#修改配置文件(訪問a.html重定向到b.html) cd /usr/local/ngin…

【Python辦公】字符分割拼接工具(GUI工具)

目錄 專欄導讀 項目簡介 功能特性 ?? 核心功能 1. 字符分割功能 2. 字符拼接功能 ?? 界面特性 現代化設計 用戶體驗優化 技術實現 開發環境 核心代碼結構 關鍵技術點 使用指南 安裝步驟 完整代碼 字符分割操作 字符拼接操作 應用場景 數據處理 文本編輯 開發輔助 項目優勢 …

Windows 命令行:dir 命令

專欄導航 上一篇:Windows 命令行:Exit 命令 回到目錄 下一篇:MFC 第一章概述 本節前言 學習本節知識,需要你首先懂得如何打開一個命令行界面,也就是命令提示符界面。鏈接如下。 參考課節:Windows 命令…

軟考高級--系統架構設計師--案例分析真題解析

提示:文章寫完后,目錄可以自動生成,如何生成可參考右邊的幫助文檔 文章目錄前言試題一 軟件架構設計一、2019年 案例分析二、2020年 案例分析三、2021年 案例分析四、2022年 案例分析試題二 軟件系統設計一、2019年 案例分析二、2020年 案例分…

css中的性能優化之content-visibility: auto

content-visibility: auto的核心機制是讓瀏覽器智能跳過屏幕外元素的渲染工作,包括布局和繪制,直到它們接近視口時才渲染。這與虛擬滾動等傳統方案相比優勢明顯,只需要一行CSS就能實現近似效果。值得注意的是必須配合contain-intrinsic-size屬…

通過uniapp將vite vue3項目打包為android系統的.apk包,并實現可自動升級功能

打包vue項目,注意vite.config.ts文件和路由文件設置 vite.config.ts,將base等配置改為./ import {fileURLToPath, URL } from node:urlimport {defineConfig } from vite import vue from @vitejs/plugin-vue import AutoImport from unplugin-auto-import/vite import Com…

經營幫租賃經營板塊:解鎖資產運營新生態,賦能企業增長新引擎

在商業浪潮奔涌向前的當下,企業資產運營與租賃管理的模式不斷迭代,“經營幫” 以其租賃經營板塊為支點,構建起涵蓋多元業務場景、適配不同需求的生態體系,成為眾多企業破局資產低效困局、挖掘增長新動能的關鍵助力。本文將深度拆解…

C語言---編譯的最小單位---令牌(Token)

文章目錄C語言中令牌幾類令牌是編譯器理解源代碼的最小功能單元,是編譯過程的第一步。C語言中令牌幾類 1、關鍵字: 具有固定含義的保留字,如 int, if, for, while, return 等。 2、標識符: 由程序員定義的名稱,用于變…

機器學習 | Python中進行特征重要性分析的9個常用方法

在Python中,特征重要性分析是機器學習模型解釋和特征選擇的關鍵步驟。以下是9種常用方法及其實現示例: 1. 基于樹的模型內置特征重要性 原理:樹模型(如隨機森林、XGBoost)根據特征分裂時的純度提升(基尼不純度/信息增益)計算重要性。 from sklearn.ensemble import Ra…

心路歷程-了解網絡相關知識

在做這個題材的時候,考慮的一個點就是:自己的最初的想法;可是技術是不斷更新的; 以前的材料會落后,但是萬變不能變其中;所以呈現出來的知識點也相對比較老舊,為什么呢? 因為最新的素…

CAT1+mqtt

文章目錄 MQTT知識點mqtt數據固定報頭可變報頭(連接請求)有效載荷 阿里云MQTT測試訂閱Topic下發數據給MQTT.fxMQTT.fx 發布消息給服務器 下載mqtt(C-嵌入式版)我的W5500項目路徑使用Cat1連接阿里云平臺AT指令串口連接1. 開機聯網2. 激活內置SIM卡(貼片卡)3. 我這里使用連接的是…

AiPPT怎么樣?好用嗎?

AiPPT怎么樣?好用嗎?AiPPT 是一款智能高效的PPT生成工具,通過AI技術快速將主題或文檔(如Word/PDF)轉化為專業PPT,提供超10萬套行業模板,覆蓋商務、教育等22場景,支持一鍵生成大綱、文…

惡補DSP:2.F28335的定時器系統

一、定時器原理F28335 城市的三座時鐘塔(Timer0、Timer1、Timer2)是城市時間管理的核心設施,每座均為32位精度,依靠城市能源脈沖(系統時鐘 SYSCLKOUT,典型頻率為150 MHz)驅動。它們由兩個核心模…

用倒計時軟件為考研備考精準導航 復習 模擬考試 日期倒計時都可以用

考研,是一場與時間的博弈。從決定報名的那一刻起,日歷上的每一個數字都被賦予了特殊意義 —— 報名截止日、現場確認期、初試倒計時、成績查詢點…… 這些節點如同航標,指引著備考者的方向。而在這場漫長的征途里,一款精準、易用的…

React學習(七)

目錄:1.react-進階-antd-搜索2.react-進階-antd-依賴項說明 3.react-進階-antd-刪除1.react-進階-antd-搜索我們jsx代碼里只能返回一個最頂層的根元素下拉框簡化寫法:把這個對象結構賦值一下:清空定義個參數類型做修改事件需要定義三個…

Unix Domain Socket(UDS)和 TCP/IP(使用 127.0.0.1)進程間通信(IPC)的比較

Unix Domain Socket(UDS)和 TCP/IP(使用 127.0.0.1 或 localhost)都是進程間通信(IPC)的方式,但它們在實現、性能和適用場景上有顯著區別。以下是兩者的對比:1. 通信機制Unix Domain…