python訓練day49 CBAM

import torch
import torch.nn as nn# 定義通道注意力
class ChannelAttention(nn.Module):def __init__(self, in_channels, ratio=16):"""通道注意力機制初始化參數:in_channels: 輸入特征圖的通道數ratio: 降維比例,用于減少參數量,默認為16"""super().__init__()# 全局平均池化,將每個通道的特征圖壓縮為1x1,保留通道間的平均值信息self.avg_pool = nn.AdaptiveAvgPool2d(1)# 全局最大池化,將每個通道的特征圖壓縮為1x1,保留通道間的最顯著特征self.max_pool = nn.AdaptiveMaxPool2d(1)# 共享全連接層,用于學習通道間的關系# 先降維(除以ratio),再通過ReLU激活,最后升維回原始通道數self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // ratio, bias=False),  # 降維層nn.ReLU(),  # 非線性激活函數nn.Linear(in_channels // ratio, in_channels, bias=False)   # 升維層)# Sigmoid函數將輸出映射到0-1之間,作為各通道的權重self.sigmoid = nn.Sigmoid()def forward(self, x):"""前向傳播函數參數:x: 輸入特征圖,形狀為 [batch_size, channels, height, width]返回:調整后的特征圖,通道權重已應用"""# 獲取輸入特征圖的維度信息,這是一種元組的解包寫法b, c, h, w = x.shape# 對平均池化結果進行處理:展平后通過全連接網絡avg_out = self.fc(self.avg_pool(x).view(b, c))# 對最大池化結果進行處理:展平后通過全連接網絡max_out = self.fc(self.max_pool(x).view(b, c))# 將平均池化和最大池化的結果相加并通過sigmoid函數得到通道權重attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)# 將注意力權重與原始特征相乘,增強重要通道,抑制不重要通道return x * attention #這個運算是pytorch的廣播機制
## 空間注意力模塊
class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):# 通道維度池化avg_out = torch.mean(x, dim=1, keepdim=True)  # 平均池化:(B,1,H,W)max_out, _ = torch.max(x, dim=1, keepdim=True)  # 最大池化:(B,1,H,W)pool_out = torch.cat([avg_out, max_out], dim=1)  # 拼接:(B,2,H,W)attention = self.conv(pool_out)  # 卷積提取空間特征return x * self.sigmoid(attention)  # 特征與空間權重相乘

## CBAM模塊
class CBAM(nn.Module):def __init__(self, in_channels, ratio=16, kernel_size=7):super().__init__()self.channel_attn = ChannelAttention(in_channels, ratio)self.spatial_attn = SpatialAttention(kernel_size)def forward(self, x):x = self.channel_attn(x)x = self.spatial_attn(x)return x
# 測試下通過CBAM模塊的維度變化
# 輸入卷積的尺寸為
# 假設輸入特征圖:batch=2,通道=512,尺寸=26x26
x = torch.randn(2, 512, 26, 26) 
cbam = CBAM(in_channels=512)
output = cbam(x)  # 輸出形狀不變:(2, 512, 26, 26)
print(f"Output shape: {output.shape}")  # 驗證輸出維度

?cnn+CBAM

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 設置中文字體支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解決負號顯示問題# 檢查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 數據預處理(與原代碼一致)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加載數據集(與原代碼一致)
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 定義帶有CBAM的CNN模型
class CBAM_CNN(nn.Module):def __init__(self):super(CBAM_CNN, self).__init__()# ---------------------- 第一個卷積塊(帶CBAM) ----------------------self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(32) # 批歸一化self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2)self.cbam1 = CBAM(in_channels=32)  # 在第一個卷積塊后添加CBAM# ---------------------- 第二個卷積塊(帶CBAM) ----------------------self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(64)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2)self.cbam2 = CBAM(in_channels=64)  # 在第二個卷積塊后添加CBAM# ---------------------- 第三個卷積塊(帶CBAM) ----------------------self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.bn3 = nn.BatchNorm2d(128)self.relu3 = nn.ReLU()self.pool3 = nn.MaxPool2d(kernel_size=2)self.cbam3 = CBAM(in_channels=128)  # 在第三個卷積塊后添加CBAM# ---------------------- 全連接層 ----------------------self.fc1 = nn.Linear(128 * 4 * 4, 512)self.dropout = nn.Dropout(p=0.5)self.fc2 = nn.Linear(512, 10)def forward(self, x):# 第一個卷積塊x = self.conv1(x)x = self.bn1(x)x = self.relu1(x)x = self.pool1(x)x = self.cbam1(x)  # 應用CBAM# 第二個卷積塊x = self.conv2(x)x = self.bn2(x)x = self.relu2(x)x = self.pool2(x)x = self.cbam2(x)  # 應用CBAM# 第三個卷積塊x = self.conv3(x)x = self.bn3(x)x = self.relu3(x)x = self.pool3(x)x = self.cbam3(x)  # 應用CBAM# 全連接層x = x.view(-1, 128 * 4 * 4)x = self.fc1(x)x = self.relu3(x)x = self.dropout(x)x = self.fc2(x)return x# 初始化模型并移至設備
model = CBAM_CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
# 訓練函數
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):model.train()all_iter_losses = []iter_indices = []train_acc_history = []test_acc_history = []train_loss_history = []test_loss_history = []for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 單Batch損失: {iter_loss:.4f} | 累計平均損失: {running_loss/(batch_idx+1):.4f}')epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totaltrain_acc_history.append(epoch_train_acc)train_loss_history.append(epoch_train_loss)# 測試階段model.eval()test_loss = 0correct_test = 0total_test = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_testtest_acc_history.append(epoch_test_acc)test_loss_history.append(epoch_test_loss)scheduler.step(epoch_test_loss)print(f'Epoch {epoch+1}/{epochs} 完成 | 訓練準確率: {epoch_train_acc:.2f}% | 測試準確率: {epoch_test_acc:.2f}%')plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc# 繪圖函數
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序號)')plt.ylabel('損失值')plt.title('每個 Iteration 的訓練損失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='訓練準確率')plt.plot(epochs, test_acc, 'r-', label='測試準確率')plt.xlabel('Epoch')plt.ylabel('準確率 (%)')plt.title('訓練和測試準確率')plt.legend()plt.grid(True)plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='訓練損失')plt.plot(epochs, test_loss, 'r-', label='測試損失')plt.xlabel('Epoch')plt.ylabel('損失值')plt.title('訓練和測試損失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 執行訓練
epochs = 50
print("開始使用帶CBAM的CNN訓練模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
print(f"訓練完成!最終測試準確率: {final_accuracy:.2f}%")# # 保存模型
# torch.save(model.state_dict(), 'cifar10_cbam_cnn_model.pth')
# print("模型已保存為: cifar10_cbam_cnn_model.pth")

@浙大疏錦行?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/87601.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/87601.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/87601.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

在小程序中實現實時聊天:WebSocket最佳實踐

前言 在當今互聯網應用中,實時通信已經成為一個標配功能,特別是對于需要即時響應的場景,如在線客服、咨詢系統等。本文將分享如何在小程序中實現一個高效穩定的WebSocket連接,以及如何處理斷線重連、消息發送與接收等常見問題。 W…

Python網絡爬蟲編程新手篇

網絡爬蟲是一種自動抓取互聯網信息的腳本程序,廣泛應用于搜索引擎、數據分析和內容聚合。這次我將帶大家使用Python快速構建一個基礎爬蟲,為什么使用python做爬蟲?主要就是支持的庫很多,而且同類型查詢文檔多,在同等情…

LeetCode.283移動零

題目鏈接:283. 移動零 - 力扣(LeetCode) 題目描述: 給定一個數組 nums,編寫一個函數將所有 0 移動到數組的末尾,同時保持非零元素的相對順序。 請注意 ,必須在不復制數組的情況下原地對數組進行…

2025年7月4日漏洞文字版表述一句話版本(漏洞危害以及修復建議),通常用于漏洞通報中簡潔干練【持續更新中】,漏洞通報中對于各類漏洞及修復指南

漏洞及修復指南 一、暗鏈 危害:攻擊者通過技術手段在用戶網頁中插入隱藏鏈接或代碼,并指向惡意網站,可導致用戶信息泄露、系統感染病毒,用戶訪問被劫持至惡意網站,泄露隱私或感染惡意軟件,被黑客利用進行…

python --飛漿離線ocr使用/paddleocr

依賴 # python3.7.3 paddleocr2.7.0.2 paddlepaddle2.5.2 loguru0.7.3from paddleocr import PaddleOCR import cv2 import numpy as npif __name__ __main__:OCR PaddleOCR(use_doc_orientation_classifyFalse, # 檢測文檔方向use_doc_unwarpingFalse, # 矯正扭曲文檔use…

數據結構與算法:貪心(三)

前言 感覺開始打cf了以后貪心的能力有了明顯的提升,讓我們謝謝cf的感覺場。 一、跳躍游戲 II class Solution { public:int jump(vector<int>& nums) {int n=nums.size();//怎么感覺這個題也在洛谷上刷過(?)int cur=0;//當前步最遠位置int next=0;//多跳一步最遠…

【Redis篇】數據庫架構演進中Redis緩存的技術必然性—高并發場景下穿透、擊穿、雪崩的體系化解決方案

&#x1f4ab;《博主主頁》&#xff1a;    &#x1f50e; CSDN主頁__奈斯DB    &#x1f50e; IF Club社區主頁__奈斯、 &#x1f525;《擅長領域》&#xff1a;擅長阿里云AnalyticDB for MySQL(分布式數據倉庫)、Oracle、MySQL、Linux、prometheus監控&#xff1b;并對…

Docker 實踐與應用案例

引言 在當今的軟件開發和部署領域&#xff0c;高效、可移植且一致的環境搭建與應用部署是至關重要的。Docker 作為一款輕量級的容器化技術&#xff0c;為解決這些問題提供了卓越的方案。Docker 通過容器化的方式&#xff0c;將應用及其依賴項打包成一個獨立的容器&#xff0c;…

《論三生原理》以非共識路徑實現技術代際躍遷??

AI輔助創作&#xff1a; 《論三生原理》以顛覆傳統數學范式的非共識路徑驅動多重技術代際躍遷&#xff0c;其突破性實踐與爭議并存&#xff0c;核心論證如下&#xff1a; 一、技術代際躍遷的實證突破? ?芯片架構革新? 為華為三進制邏輯門芯片提供理論支撐&#xff0c;通過對…

一體機電腦為何熱度持續上升?消費者更看重哪些功能?

一體機電腦&#xff08;AIO&#xff0c;All-in-One&#xff09;將主機硬件與顯示器集成于單一機身。通常僅需連接電源線&#xff0c;配備無線鍵盤、鼠標即可啟用。相比傳統臺式電腦和筆記本電腦&#xff0c;選購一體機的客戶更看重一體機的以下特點。 一體機憑借其節省空間、簡…

無人機載重模塊技術要點分析

一、技術要點 1. 結構設計創新 雙電機卷揚系統&#xff1a;采用主電機&#xff08;張力控制&#xff09;和副電機&#xff08;卷揚控制&#xff09;協同工作&#xff0c;解決繩索纏繞問題&#xff0c;支持30米繩長1.2m/s高速收放&#xff0c;重載穩定性提升。 軸雙槳布局…

【大模型推理】工作負載的彈性伸縮

基于Knative的LLM推理場景彈性伸縮方案 1.QPS 不是一個好的 pod autoscaling indicator 在LLM推理中&#xff0c; 為什么 2. concurrency適用于單次請求資源消耗大且處理時間長的業務&#xff0c;而rps則適合較短處理時間的業務。 3.“反向彈性伸縮”的概念 4。 區分兩種不同的…

STM32F103_Bootloader程序開發12 - IAP升級全流程

導言 本教程使用正點原子戰艦板開發。 《STM32F103_Bootloader程序開發11 - 實現 App 安全跳轉至 Bootloader》上一章節實現App跳轉bootloader&#xff0c;接著&#xff0c;跳轉到bootloader后&#xff0c;下位機要發送報文‘C’給IAP上位機&#xff0c;表示我準備好接收固件數…

AI驅動的未來軟件工程范式

引言&#xff1a;邁向智能驅動的軟件工程新范式 本文是一份關于構建和實施“AI驅動的全生命周期軟件工程范式”的簡要集成指南。它旨在提供一個獨立、完整、具體的框架&#xff0c;指導組織如何將AI智能體深度融合到軟件開發的每一個環節&#xff0c;實現從概念到運維的智能化…

Hawk Insight|美國6月非農數據點評:情況遠沒有看上去那么好

7月3日&#xff0c;美國近期最重要的勞動力數據——6月非農數據公布。在ADP遇冷之后&#xff0c;市場對這份報告格外期待。 根據美國勞工統計局公布報告&#xff0c;美國6月非農就業人口增加 14.7萬人&#xff0c;預期 10.6萬人&#xff0c;4月和5月非農就業人數合計上修1.6萬人…

Python 的內置函數 reversed

Python 內建函數列表 > Python 的內置函數 reversed Python 的內置函數 reversed() 是一個用于序列反轉的高效工具函數&#xff0c;它返回一個反向迭代器對象。以下是關于該函數的詳細說明&#xff1a; 基本用法 語法&#xff1a;reversed(seq)參數&#xff1a;seq 可以是…

溝通-交流-說話-gt-jl-sh-goutong-jiaoliu-shuohua

溝通,先看|問狀態(情緒) 老婆下班回家,我說,到哪兒了,買點玉米哦;她說你為啥不買, 我說怎么如此大火氣, 她說你安排我&#xff0c;我不情愿;你怎么看 和女人溝通不能目標優先 先問狀態并表達關心 用感謝代替要求&#xff08;“你上次買的玉米特別甜&#xff0c;今天突然又饞了…

Ubuntu20.04運DS-5

準備工作&#xff1a; cd /home/rlk/rlk/runninglinuxkernel_5.0 #make clean mkdir _install_arm64/dev sudo mknod _install_arm64/dev/console c 5 1 ./build_ds5_arm64.sh git checkout boot-wrapper-aarch64/fvp-base-gicv3-psci.dtb ./build_ds5_arm64.sh創建工程步驟2.5…

區塊鏈網絡P2P通信原理

目錄 區塊鏈網絡P2P通信原理引言:去中心化的網絡基石1. P2P網絡基礎架構1.1 區塊鏈網絡拓撲1.2 節點類型對比2. 節點發現與連接2.1 初始引導過程2.2 節點發現協議3. 網絡通信協議3.1 消息結構3.2 核心消息類型4. 數據傳播機制4.1 交易傳播流程4.2 Gossip協議實現4.3 區塊傳播優…

RNN和Transformer區別

RNN&#xff08;循環神經網絡&#xff09;和 Transformer 是兩種廣泛應用于自然語言處理&#xff08;NLP&#xff09;和其他序列任務的深度學習架構。它們在設計理念、性能特點和應用場景上存在顯著區別。以下是它們的詳細對比&#xff1a;1. 基本架構RNN&#xff08;循環神經網…