Python訓練營打卡Day44-通道注意力(SE注意力)

知識點回顧:

  1. 不同CNN層的特征圖:不同通道的特征圖
  2. 什么是注意力:注意力家族,類似于動物園,都是不同的模塊,好不好試了才知道。
  3. 通道注意力:模型的定義和插入的位置
  4. 通道注意力后的特征圖和熱力圖

內容參考

作業:

  1. 今日代碼較多,理解邏輯即可
  2. 對比不同卷積層特征圖可視化的結果(可選)

一、 什么是注意力

其中注意力機制是一種讓模型學會「選擇性關注重要信息」的特征提取器,就像人類視覺會自動忽略背景,聚焦于圖片中的主體(如貓、汽車)。

transformer中的叫做自注意力機制,他是一種自己學習自己的機制,他可以自動學習到圖片中的主體,并忽略背景。我們現在說的很多模塊,比如通道注意力、空間注意力、通道注意力等等,都是基于自注意力機制的。

從數學角度看,注意力機制是對輸入特征進行加權求和,輸出=∑(輸入特征×注意力權重),其中注意力權重是學習到的。所以他和卷積很像,因為卷積也是一種加權求和。但是卷積是 “固定權重” 的特征提取(如 3x3 卷積核)--訓練完了就結束了,注意力是 “動態權重” 的特征提取(權重隨輸入數據變化)---輸入數據不同權重不同。

問:為什么需要多種注意力模塊?

答:因為不同場景下的關鍵信息分布不同。例如,識別鳥類和飛機時,需關注 “羽毛紋理”“金屬光澤” 等特定通道的特征,通道注意力可強化關鍵通道;而物體位置不確定時(如貓出現在圖像不同位置),空間注意力能聚焦物體所在區域,忽略背景。復雜場景中,可能需要同時關注通道和空間(如混合注意力模塊 CBAM),或處理長距離依賴(如全局注意力模塊 Non-local)。

問:為什么不設計一個萬能注意力模塊?

答:主要受效率和靈活性限制。專用模塊針對特定需求優化計算,成本更低(如通道注意力僅需處理通道維度,無需全局位置計算);不同任務的核心需求差異大(如醫學圖像側重空間定位,自然語言處理側重語義長距離依賴),通用模塊可能冗余或低效。每個模塊新增的權重會增加模型參數量,若訓練數據不足或優化不當,可能引發過擬合。因此實際應用中需結合輕量化設計(如減少全連接層參數)、正則化(如 Dropout)或結構約束(如共享注意力權重)來平衡性能與復雜度。

通道注意力(Channel Attention)屬于注意力機制(Attention Mechanism)的變體,而非自注意力(Self-Attention)的直接變體。可以理解為注意力是一個動物園算法,里面很多個物種,自注意力只是一個分支,因為開創了transformer所以備受矚目。我們今天的內容用通道注意力舉例。

常見注意力模塊的歸類如下

二、 特征圖的提取

2.2 特征圖可視化

1.初始化設置:將模型設為評估模式,準備類別名稱列表(如飛機、汽車等)。

2.數據加載與處理:

①從測試數據加載器中獲取圖像和標簽。

②僅處理前 num_images?張圖像(如2張)。

3.注冊鉤子捕獲特征圖:

①為指定層(如 conv1, conv2, conv3)注冊前向鉤子。

②鉤子函數將這些層的輸出(特征圖)保存到字典中。

4.前向傳播與特征提取:

①模型處理圖像,觸發鉤子函數,獲取并保存特征圖。

②移除鉤子,避免后續干擾。

5.可視化特征圖:

?對每張圖像

?①恢復原始像素值并顯示。

?②為每個目標層創建子圖,展示前 num_channels?個通道的特征圖(如9個通道)。

?③每個通道的特征圖以網格形式排列,顯示通道編號。

關鍵細節

①特征圖布局:原始圖像在左側,各層特征圖按順序排列在右側。

②通道選擇:默認顯示前9個通道(按重要性或索引排序)。

③顯示優化:

使用 inset_axes?在大圖中嵌入小網格,清晰展示每個通道;

層標題與通道標題分開,避免重疊;

反標準化處理恢復圖像原始色彩。

def visualize_feature_maps(model, test_loader, device, layer_names, num_images=3, num_channels=9):"""可視化指定層的特征圖(修復循環冗余問題)參數:model: 模型test_loader: 測試數據加載器layer_names: 要可視化的層名稱(如['conv1', 'conv2', 'conv3'])num_images: 可視化的圖像總數num_channels: 每個圖像顯示的通道數(取前num_channels個通道)"""model.eval()  # 設置為評估模式class_names = ['飛機', '汽車', '鳥', '貓', '鹿', '狗', '青蛙', '馬', '船', '卡車']# 從測試集加載器中提取指定數量的圖像(避免嵌套循環)images_list, labels_list = [], []for images, labels in test_loader:images_list.append(images)labels_list.append(labels)if len(images_list) * test_loader.batch_size >= num_images:break# 拼接并截取到目標數量images = torch.cat(images_list, dim=0)[:num_images].to(device)labels = torch.cat(labels_list, dim=0)[:num_images].to(device)with torch.no_grad():# 存儲各層特征圖feature_maps = {}# 保存鉤子句柄hooks = []# 定義鉤子函數,捕獲指定層的輸出def hook(module, input, output, name):feature_maps[name] = output.cpu()  # 保存特征圖到字典# 為每個目標層注冊鉤子,并保存鉤子句柄for name in layer_names:module = getattr(model, name)hook_handle = module.register_forward_hook(lambda m, i, o, n=name: hook(m, i, o, n))hooks.append(hook_handle)# 前向傳播觸發鉤子_ = model(images)# 正確移除鉤子for hook_handle in hooks:hook_handle.remove()# 可視化每個圖像的各層特征圖(僅一層循環)for img_idx in range(num_images):img = images[img_idx].cpu().permute(1, 2, 0).numpy()# 反標準化處理(恢復原始像素值)img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)img = np.clip(img, 0, 1)  # 確保像素值在[0,1]范圍內# 創建子圖num_layers = len(layer_names)fig, axes = plt.subplots(1, num_layers + 1, figsize=(4 * (num_layers + 1), 4))# 顯示原始圖像axes[0].imshow(img)axes[0].set_title(f'原始圖像\n類別: {class_names[labels[img_idx]]}')axes[0].axis('off')# 顯示各層特征圖for layer_idx, layer_name in enumerate(layer_names):fm = feature_maps[layer_name][img_idx]  # 取第img_idx張圖像的特征圖fm = fm[:num_channels]  # 僅取前num_channels個通道num_rows = int(np.sqrt(num_channels))num_cols = num_channels // num_rows if num_rows != 0 else 1# 創建子圖網格layer_ax = axes[layer_idx + 1]layer_ax.set_title(f'{layer_name}特征圖 \n')# 加個換行讓文字分離上去layer_ax.axis('off')  # 關閉大子圖的坐標軸# 在大子圖內創建小網格for ch_idx, channel in enumerate(fm):ax = layer_ax.inset_axes([ch_idx % num_cols / num_cols, (num_rows - 1 - ch_idx // num_cols) / num_rows, 1/num_cols, 1/num_rows])ax.imshow(channel.numpy(), cmap='viridis')ax.set_title(f'通道 {ch_idx + 1}')ax.axis('off')plt.tight_layout()plt.show()# 調用示例(按需修改參數)
layer_names = ['conv1', 'conv2', 'conv3']
visualize_feature_maps(model=model,test_loader=test_loader,device=device,layer_names=layer_names,num_images=5,  # 可視化5張測試圖像 → 輸出5張大圖num_channels=9   # 每張圖像顯示前9個通道的特征圖
)

上面的圖為提取CNN不同卷積層輸出的特征圖,我們以第五張圖片-青蛙 進行解讀。

由于經過了不斷的下采樣,特征變得越來越抽象,人類已經無法理解。

?核心作用

通過可視化特征圖,可直觀觀察:

①淺層卷積層(如 conv1)如何捕獲邊緣、紋理等低級特征。

②深層卷積層(如 conv3 )如何組合低級特征形成語義概念(如物體部件)。

③模型對不同類別的關注區域差異(如鳥類的羽毛紋理 vs. 飛機的金屬光澤)。

conv1 特征圖(淺層卷積)

特點:

①保留較多原始圖像的細節紋理(如植物葉片、青蛙身體的邊緣輪廓)。

②通道間差異相對小,每個通道都能看到類似原始圖像的基礎結構(如通道 1 - 9 都能識別邊緣、紋理)。

意義:

①提取低級特征(邊緣、顏色塊、簡單紋理),是后續高層特征的“原材料”。

②類似人眼初步識別圖像的輪廓和基礎結構。

conv2 特征圖(中層卷積)

特點:

①空間尺寸(高、寬)比 conv1 更小(因卷積/池化下采樣),但語義信息更抽象。

②通道間差異更明顯:部分通道開始聚焦局部關鍵特征(如通道 5、8 中黃色高亮區域,可能對應青蛙身體或植物的關鍵紋理)。

意義:

①對 conv1 的低級特征進行組合與篩選,提取中級特征(如局部形狀、紋理組合)。

②類似人眼從“邊緣輪廓”過渡到“識別局部結構”(如青蛙的身體塊、植物的葉片簇)。

?conv3 特征圖(深層卷積)

特點:

①空間尺寸進一步縮小,抽象程度最高,肉眼難直接對應原始圖像細節。

②通道間差異極大,部分通道聚焦全局語義特征(如通道 4、7 中黃色區域,可能對應模型判斷“青蛙”類別的關鍵特征)。

意義:

①對 conv2 的中級特征進行全局整合,提取高級語義特征(如物體類別相關的抽象模式)。

②類似人眼最終“識別出這是青蛙”的關鍵依據,模型通過這些特征判斷類別。

三、通道注意力

3.1 通道注意力的定義

# ===================== 新增:通道注意力模塊(SE模塊) =====================
class ChannelAttention(nn.Module):"""通道注意力模塊(Squeeze-and-Excitation)"""def __init__(self, in_channels, reduction_ratio=16):"""參數:in_channels: 輸入特征圖的通道數reduction_ratio: 降維比例,用于減少參數量"""super(ChannelAttention, self).__init__()# 全局平均池化 - 將空間維度壓縮為1x1,保留通道信息self.avg_pool = nn.AdaptiveAvgPool2d(1)# 全連接層 + 激活函數,用于學習通道間的依賴關系self.fc = nn.Sequential(# 降維:壓縮通道數,減少計算量nn.Linear(in_channels, in_channels // reduction_ratio, bias=False),nn.ReLU(inplace=True),# 升維:恢復原始通道數nn.Linear(in_channels // reduction_ratio, in_channels, bias=False),# Sigmoid將輸出值歸一化到[0,1],表示通道重要性權重nn.Sigmoid())def forward(self, x):"""參數:x: 輸入特征圖,形狀為 [batch_size, channels, height, width]返回:加權后的特征圖,形狀不變"""batch_size, channels, height, width = x.size()# 1. 全局平均池化:[batch_size, channels, height, width] → [batch_size, channels, 1, 1]avg_pool_output = self.avg_pool(x)# 2. 展平為一維向量:[batch_size, channels, 1, 1] → [batch_size, channels]avg_pool_output = avg_pool_output.view(batch_size, channels)# 3. 通過全連接層學習通道權重:[batch_size, channels] → [batch_size, channels]channel_weights = self.fc(avg_pool_output)# 4. 重塑為二維張量:[batch_size, channels] → [batch_size, channels, 1, 1]channel_weights = channel_weights.view(batch_size, channels, 1, 1)# 5. 將權重應用到原始特征圖上(逐通道相乘)return x * channel_weights  # 輸出形狀:[batch_size, channels, height, width]

3.2 模型的重新定義(通道注意力的插入)

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()  # ---------------------- 第一個卷積塊 ----------------------self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.bn1 = nn.BatchNorm2d(32)self.relu1 = nn.ReLU()# 新增:插入通道注意力模塊(SE模塊)self.ca1 = ChannelAttention(in_channels=32, reduction_ratio=16)  self.pool1 = nn.MaxPool2d(2, 2)  # ---------------------- 第二個卷積塊 ----------------------self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.bn2 = nn.BatchNorm2d(64)self.relu2 = nn.ReLU()# 新增:插入通道注意力模塊(SE模塊)self.ca2 = ChannelAttention(in_channels=64, reduction_ratio=16)  self.pool2 = nn.MaxPool2d(2)  # ---------------------- 第三個卷積塊 ----------------------self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.bn3 = nn.BatchNorm2d(128)self.relu3 = nn.ReLU()# 新增:插入通道注意力模塊(SE模塊)self.ca3 = ChannelAttention(in_channels=128, reduction_ratio=16)  self.pool3 = nn.MaxPool2d(2)  # ---------------------- 全連接層(分類器) ----------------------self.fc1 = nn.Linear(128 * 4 * 4, 512)self.dropout = nn.Dropout(p=0.5)self.fc2 = nn.Linear(512, 10)def forward(self, x):# ---------- 卷積塊1處理 ----------x = self.conv1(x)       x = self.bn1(x)         x = self.relu1(x)       x = self.ca1(x)  # 應用通道注意力x = self.pool1(x)       # ---------- 卷積塊2處理 ----------x = self.conv2(x)       x = self.bn2(x)         x = self.relu2(x)       x = self.ca2(x)  # 應用通道注意力x = self.pool2(x)       # ---------- 卷積塊3處理 ----------x = self.conv3(x)       x = self.bn3(x)         x = self.relu3(x)       x = self.ca3(x)  # 應用通道注意力x = self.pool3(x)       # ---------- 展平與全連接層 ----------x = x.view(-1, 128 * 4 * 4)  x = self.fc1(x)           x = self.relu3(x)         x = self.dropout(x)       x = self.fc2(x)           return x  # 重新初始化模型,包含通道注意力模塊
model = CNN()
model = model.to(device)  # 將模型移至GPU(如果可用)criterion = nn.CrossEntropyLoss()  # 交叉熵損失函數
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam優化器# 引入學習率調度器,在訓練過程中動態調整學習率--訓練初期使用較大的 LR 快速降低損失,訓練后期使用較小的 LR 更精細地逼近全局最優解。
# 在每個 epoch 結束后,需要手動調用調度器來更新學習率,可以在訓練過程中調用 scheduler.step()
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,        # 指定要控制的優化器(這里是Adam)mode='min',       # 監測的指標是"最小化"(如損失函數)patience=3,       # 如果連續3個epoch指標沒有改善,才降低LRfactor=0.5        # 降低LR的比例(新LR = 舊LR × 0.5)
)# 訓練模型(復用原有的train函數)
print("開始訓練帶通道注意力的CNN模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs=50)
print(f"訓練完成!最終測試準確率: {final_accuracy:.2f}%")

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/94679.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/94679.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/94679.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

shiro進行解密

目錄Shiro 解密的核心注意事項1. 密碼處理:堅決避免 “可逆解密”2.例子【自己模擬數據庫,未連數據庫】:Shiro 解密的核心注意事項 1. 密碼處理:堅決避免 “可逆解密” 禁用明文存儲:永遠不要將明文密碼存入數據庫,必須使用 Has…

更改 Microsoft Edge 瀏覽器的緩存與用戶數據目錄位置

Microsoft Edge瀏覽器默認會將緩存文件和用戶數據存儲在系統盤(通常是C盤),隨著使用時間的增長,這些文件可能會占用大量空間。本文將詳細介紹多種更改Edge瀏覽器緩存位置和用戶數據目錄位置的方法,幫助您更好地管理磁盤…

【傳奇開心果系列】Flet框架實現的圖形化界面的PDF轉word轉換器辦公小工具自定義模板

let框架實現的圖形化界面的PDF轉word轉換器辦公小工具自定義模板一、效果展示截圖二、PDF轉Word轉換器概括介紹三、功能特性四、安裝依賴五、運行程序六、使用說明七、注意事項八、技術棧九、系統要求十、源碼下載地址 一、效果展示截圖二、PDF轉Word轉換器概括介紹 一個基于Fl…

STM32 定時器(PWM輸入捕獲)

以下是基于STM32標準庫(以STM32F103為例)實現PWM輸入模式(自動雙沿捕獲)的完整代碼,通過配置定時器的PWM輸入模式,可自動捕獲外部PWM信號的周期(頻率)?和占空比,無需手動…

Web安全開發指導規范文檔V1.0

一、背景 團隊最近頻繁遭受網絡攻擊,引起了部門技術負責人的重視,筆者在團隊中相對來說更懂安全,因此花了點時間編輯了一份安全開發自檢清單,覺得應該也有不少讀者有需要,所以將其分享出來。 二、編碼安全 2.1 輸入驗證 說明 檢查項 概述 任何來自客戶端的數據,如URL和…

在Godot中為您的游戲添加并控制游戲角色的完整技術指南

這是一個在Godot中為您的游戲添加并控制玩家角色的完整技術指南。這個過程分為三大步:?準備資源、構建場景、編寫控制腳本。道可道,非常道,名可名,非常名!第一步:準備資源(建模與動畫&#xff…

Flink 狀態 RocksDBListState(寫入時的Merge優化)

RocksDBListState<K, N, V> RocksDBListState 繼承自 AbstractRocksDBState<K, N, List<V>>&#xff0c;并實現了 InternalListState<K, N, V> 接口。繼承 AbstractRocksDBState: 這意味著它天然獲得了與 RocksDB 交互的底層能力&#xff0c;包括&…

zookeeper-保姆級配置說明

一. 基本配置&#xff1a;clientPort&#xff1a; 客戶端連接的服務器所監聽的tcp端口&#xff0c;默認2181dataDir&#xff1a;內存數據庫保存的數據路徑。myid也存放在這個目錄下&#xff0c;數據以異步方式寫入。dataLogDir&#xff1a;事務日志存放路徑。服務在確認一個事務…

半小時打造七夕傳統文化網站:Qoder AI編程實戰記錄

背景 最近七夕到了&#xff0c;恰逢Qoder上線&#xff0c;萌生了一個想法&#xff0c;寫一個以中國傳統七夕為主題的網站。 七夕中國傳統情人節 Qoder 介紹 Qoder 是阿里巴巴推出的一款旨在提升開發效率的 AI 編程平臺。它通過上下文工程技術和智能體輔助&#xff0c;幫助開…

常見的 Loader 和 Plugin?

Loader: babel-loader&#xff1a;將ES6的代碼轉換成ES5的代碼。css-loader&#xff1a;解析CSS文件&#xff0c;并處理CSS中的依賴關系。style-loader&#xff1a;將CSS代碼注入到HTML文檔中。file-loader&#xff1a;解析文件路徑&#xff0c;將文件賦值到輸出目錄&#xff0…

設計模式學習筆記-----抽象策略模式

抽象策略模式由五個核心組件組成策略接口定義所有策略的統一規范&#xff0c;是策略模式的 "契約"mark()&#xff1a;策略的唯一標識&#xff08;類似字典的 key&#xff09;&#xff0c;默認返回 null&#xff0c;需具體策略實現類重寫&#xff08;如InterviewSubje…

RabbitMQ面試精講 Day 30:RabbitMQ面試真題解析與答題技巧

【RabbitMQ面試精講 Day 30】RabbitMQ面試真題解析與答題技巧 開篇&#xff1a;系列收官之作&#xff0c;直擊面試核心 今天是“RabbitMQ面試精講”系列的第30天&#xff0c;也是本系列的收官之作。經過前29天對RabbitMQ核心概念、高級特性、集群架構、性能調優與開發運維的系…

Coze Studio開源版:AI Agent開發平臺的深度技術解析- 入門篇

Coze Studio開源版&#xff1a;AI Agent開發平臺的深度技術解析 引言 在人工智能快速發展的今天&#xff0c;AI Agent&#xff08;智能體&#xff09;已成為連接大語言模型與實際應用場景的重要橋梁。然而&#xff0c;構建一個功能完整、性能穩定的AI Agent開發平臺并非易事&am…

一文了解 DeepSeek 系列模型的演進與創新

近年來&#xff0c;DeepSeek 團隊在大語言模型&#xff08;LLM&#xff09;領域持續發力&#xff0c;圍繞模型架構、專家路由、推理效率、訓練方法等方面不斷優化&#xff0c;推出了一系列性能強勁的開源模型。本文對 DeepSeek 系列的關鍵論文進行了梳理&#xff0c;幫助大家快…

開源大模型本地部署

一、大模型 T5\BERT\GPT → Transformer的兒子→自注意力機制神經網絡 大模型&#xff0c; Large Model&#xff0c;是指參數規模龐大、訓練數據量巨大、具有強泛化能力的人工智能模型&#xff0c;典型代表如GPT、BERT、PaLM等。它們通常基于深度神經網絡&#xff0c;特別是T…

DAY 57 經典時序預測模型1

知識點回顧 序列數據的處理&#xff1a; 處理非平穩性&#xff1a;n階差分處理季節性&#xff1a;季節性差分自回歸性無需處理 模型的選擇 AR(p) 自回歸模型&#xff1a;當前值受到過去p個值的影響MA(q) 移動平均模型&#xff1a;當前值收到短期沖擊的影響&#xff0c;且沖擊影…

貪吃蛇游戲(純HTML)

一、游戲截圖二、源碼 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>離譜貪吃蛇</title>…

InnoDB詳解2

InnoDB詳解2一.行結構1.結構圖2.InnoDB支持的數據行格式1&#xff09;查看當前數據庫或表的行格式2&#xff09;指定行格式3&#xff09;DYNAMIC 格式的組成3.數據區存儲真實數據方式4.行的額外(管理)信息區5.頭信息區域1&#xff09;刪除一行記錄時在InnoDB內部執行的操作6.Nu…

Rust系統編程實戰:駕馭內存安全、無畏并發與WASM跨平臺開發

簡介本文深入探討Rust在系統編程領域的核心實戰應用&#xff0c;通過代碼示例解析其所有權機制如何保障內存安全&#xff0c;如何利用 fearless concurrency 構建高性能并發應用&#xff0c;并實踐如何將Rust代碼編譯為WebAssembly&#xff08;WASM&#xff09;以突破性能瓶頸。…

JavaScript 基礎入門:從概念解析到流程控制

文章目錄1. JavaScript 核心認知1.1 瀏覽器與 JavaScript 的關系1.2 JavaScript 的三大核心組成1.3 JavaScript 引入1.3.1 內聯腳本&#xff08;事件屬性綁定&#xff09;1.3.2 內部腳本&#xff08;<script> 標簽嵌入&#xff09;1.3.3 外部腳本&#xff08;獨立 .js 文…