VGG改進(8):融合Self-Attention的CNN架構

1. 自注意力機制簡介

自注意力機制是Transformer架構的核心組件,它能夠計算輸入序列中每個元素與其他所有元素的相關性。與CNN的局部感受野不同,自注意力機制允許模型直接建立遠距離依賴關系,從而捕獲全局上下文信息。

在計算機視覺中,這意味著模型不僅能夠關注圖像的局部特征(如邊緣、紋理),還能理解這些特征在全局范圍內的相互關系。這種能力對于復雜視覺任務(如場景理解、細粒度分類)尤為重要。

2. VGG16架構回顧

VGG16由牛津大學視覺幾何組提出,其核心特點是使用小尺寸卷積核(3×3)構建深度網絡。網絡包含5個卷積塊,每個塊后接最大池化層進行下采樣,最后通過三個全連接層完成分類。

VGG16的優勢在于其簡潔性和有效性,但局限性也很明顯:卷積操作的局部性限制了模型捕獲長距離依賴的能力,而全連接層的參數量過大容易導致過擬合。

3. 自注意力與CNN的融合策略

將自注意力機制引入CNN有多種方式,本文實現的是一種局部-全局特征融合策略:在CNN提取局部特征后,通過自注意力機制增強這些特征的全局上下文信息。

具體來說,我們在VGG16的特定卷積塊后插入Transformer編碼器層,使模型能夠在不同抽象層次上融合全局信息。這種設計有以下優勢:

  1. 多尺度特征增強:在不同深度的卷積層后添加注意力,可以捕獲從低級到高級的多尺度全局信息

  2. 計算效率:僅在選定位置添加注意力模塊,平衡了性能與計算開銷

  3. 架構靈活性:可以選擇在不同深度添加注意力,適應不同任務的需求

4. 代碼實現解析

4.1 自注意力機制實現

class SelfAttention(nn.Module):"""標準Transformer自注意力機制"""def __init__(self, embed_dim, num_heads=8, dropout=0.1):super(SelfAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_headsassert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)self.out_proj = nn.Linear(embed_dim, embed_dim)self.dropout = nn.Dropout(dropout)

自注意力模塊首先通過線性變換生成查詢(Query)、鍵(Key)和值(Value)三個矩陣,然后將輸入分割成多個頭進行并行計算,最后將結果合并并通過輸出投影層。

4.2 Transformer編碼器層

class TransformerEncoderLayer(nn.Module):"""Transformer編碼器層"""def __init__(self, embed_dim, num_heads=8, dropout=0.1, expansion_factor=4):super(TransformerEncoderLayer, self).__init__()self.self_attn = SelfAttention(embed_dim, num_heads, dropout)self.norm1 = nn.LayerNorm(embed_dim)self.norm2 = nn.LayerNorm(embed_dim)self.ffn = nn.Sequential(nn.Linear(embed_dim, embed_dim * expansion_factor),nn.ReLU(inplace=True),nn.Dropout(dropout),nn.Linear(embed_dim * expansion_factor, embed_dim),nn.Dropout(dropout))

編碼器層遵循標準Transformer結構,包含一個自注意力子層和一個前饋神經網絡子層,每個子層都使用殘差連接和層歸一化。

4.3 VGG16與注意力的融合

class VGG16WithAttention(nn.Module):def __init__(self, num_classes=1000, attention_positions=[3, 4]):super(VGG16WithAttention, self).__init__()# 卷積特征提取層self.features = nn.Sequential(...)self.attention_positions = attention_positionsself.attention_layers = nn.ModuleDict()# 在指定位置添加注意力層if 3 in attention_positions:self.attention_layers['block3'] = TransformerEncoderLayer(256)if 4 in attention_positions:self.attention_layers['block4'] = TransformerEncoderLayer(512)if 5 in attention_positions:self.attention_layers['block5'] = TransformerEncoderLayer(512)

在VGG16WithAttention類中,我們保留了原始VGG16的特征提取層,并在指定位置添加了Transformer編碼器層。用戶可以通過attention_positions參數靈活選擇在哪些卷積塊后添加注意力機制。

4.4 前向傳播過程

def forward(self, x):features = []# 逐層處理特征for i, layer in enumerate(self.features):x = layer(x)# 在特定卷積塊后應用注意力if i == 14 and 3 in self.attention_positions:  # 第三卷積塊結束x = self._apply_attention(x, 'block3')elif i == 21 and 4 in self.attention_positions:  # 第四卷積塊結束x = self._apply_attention(x, 'block4')elif i == 28 and 5 in self.attention_positions:  # 第五卷積塊結束x = self._apply_attention(x, 'block5')

在前向傳播過程中,模型首先通過卷積層提取特征,然后在指定位置將特征圖重塑為序列形式,應用自注意力機制,最后恢復為原始形狀繼續傳播。

5. 注意力應用的技術細節

5.1 特征圖序列化

將2D特征圖轉換為序列是應用自注意力的關鍵步驟:

def _apply_attention(self, x, block_name):"""應用自注意力機制"""batch_size, channels, height, width = x.size()# 將特征圖重塑為序列形式 [batch_size, seq_len, embed_dim]x_reshaped = x.view(batch_size, channels, -1).transpose(1, 2)# 應用注意力attended = self.attention_layers[block_name](x_reshaped)# 恢復原始形狀attended = attended.transpose(1, 2).view(batch_size, channels, height, width)return attended

這里,我們將空間維度(高度×寬度)展平為序列長度,通道維度作為嵌入維度。這種處理方式允許自注意力機制在空間維度上建立全局依賴關系。

5.2 位置編碼的考慮

值得注意的是,本文實現的版本沒有顯式添加位置編碼。在標準Transformer中,位置編碼用于提供序列中元素的位置信息。對于圖像任務,位置信息至關重要,因為像素間的空間關系具有重要含義。

在實際應用中,可以考慮添加以下類型的位置編碼:

  1. 可學習的位置編碼:隨機初始化并通過訓練學習

  2. 正弦位置編碼:使用不同頻率的正弦和余弦函數

  3. 相對位置編碼:編碼元素間的相對位置而非絕對位置

6. 模型優勢與應用場景

6.1 優勢分析

  1. 全局上下文建模:自注意力機制使模型能夠捕獲長距離依賴,理解圖像全局結構

  2. 多尺度特征融合:在不同深度添加注意力,實現了多尺度特征的全局融合

  3. 架構靈活性:可以選擇性地在不同階段添加注意力,平衡性能與計算開銷

  4. 即插即用:注意力模塊可以輕松集成到現有CNN架構中,無需大幅修改

6.2 應用場景

這種混合架構特別適合以下計算機視覺任務:

  1. 細粒度圖像分類:需要捕獲細微特征差異和全局上下文關系

  2. 場景理解:需要理解場景中多個對象的空間和語義關系

  3. 圖像分割:全局上下文信息有助于提高邊界準確性和語義一致性

  4. 目標檢測:注意力機制可以幫助模型關注相關區域,提高檢測精度

7. 實驗與性能分析

為了驗證融合注意力的VGG16的性能,我們在多個數據集上進行了實驗。與原始VGG16相比,融合模型在以下方面表現出優勢:

  1. 分類準確率:在ImageNet等復雜數據集上,準確率有顯著提升

  2. 收斂速度:注意力機制有助于梯度傳播,加速模型收斂

  3. 魯棒性:對遮擋、旋轉等干擾因素表現出更好的魯棒性

然而,注意力機制也帶來了一定的計算開銷,參數量和計算量都有所增加。在實際應用中需要根據任務需求和資源約束進行權衡。

8. 擴展與變體

本文介紹的基礎架構可以進一步擴展:

  1. 多頭注意力:使用多個注意力頭捕獲不同類型的依賴關系

  2. 跨尺度注意力:在不同尺度的特征圖間應用注意力機制

  3. 高效注意力:使用線性注意力、局部注意力等變體降低計算復雜度

  4. 預訓練與微調:在大規模數據集上預訓練后遷移到特定任務

9. 實踐建議

對于希望在實際項目中應用此架構的研究人員和工程師,以下建議可能有所幫助:

  1. 注意力位置選擇:淺層注意力捕獲空間關系,深層注意力捕獲語義關系

  2. 計算資源權衡:在計算資源受限時,可以選擇性添加注意力或使用高效變體

  3. 逐步集成:先從單個注意力層開始,逐步增加復雜度

  4. 可視化分析:使用注意力可視化工具理解模型關注區域

完整代碼

如下:

import torch
import torch.nn as nn
import mathclass SelfAttention(nn.Module):"""標準Transformer自注意力機制"""def __init__(self, embed_dim, num_heads=8, dropout=0.1):super(SelfAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_headsassert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)self.out_proj = nn.Linear(embed_dim, embed_dim)self.dropout = nn.Dropout(dropout)def forward(self, x):batch_size, seq_len, embed_dim = x.size()# 生成Q, K, Vqkv = self.qkv_proj(x).chunk(3, dim=-1)q, k, v = [part.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) for part in qkv]# 計算注意力分數scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)attn_weights = torch.softmax(scores, dim=-1)attn_weights = self.dropout(attn_weights)# 應用注意力權重attn_output = torch.matmul(attn_weights, v)attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)# 輸出投影output = self.out_proj(attn_output)return outputclass TransformerEncoderLayer(nn.Module):"""Transformer編碼器層"""def __init__(self, embed_dim, num_heads=8, dropout=0.1, expansion_factor=4):super(TransformerEncoderLayer, self).__init__()self.self_attn = SelfAttention(embed_dim, num_heads, dropout)self.norm1 = nn.LayerNorm(embed_dim)self.norm2 = nn.LayerNorm(embed_dim)self.ffn = nn.Sequential(nn.Linear(embed_dim, embed_dim * expansion_factor),nn.ReLU(inplace=True),nn.Dropout(dropout),nn.Linear(embed_dim * expansion_factor, embed_dim),nn.Dropout(dropout))def forward(self, x):# 自注意力子層attn_output = self.self_attn(x)x = self.norm1(x + attn_output)# 前饋網絡子層ffn_output = self.ffn(x)x = self.norm2(x + ffn_output)return xclass VGG16WithAttention(nn.Module):def __init__(self, num_classes=1000, attention_positions=[3, 4]):"""Args:num_classes: 分類數量attention_positions: 在哪些卷積塊后添加注意力機制 (1-5)"""super(VGG16WithAttention, self).__init__()# 卷積特征提取層self.features = nn.Sequential(# 第一層卷積塊nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# 第二層卷積塊nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# 第三層卷積塊nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# 第四層卷積塊nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# 第五層卷積塊nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)self.attention_positions = attention_positionsself.attention_layers = nn.ModuleDict()# 在指定位置添加注意力層if 3 in attention_positions:self.attention_layers['block3'] = TransformerEncoderLayer(256)if 4 in attention_positions:self.attention_layers['block4'] = TransformerEncoderLayer(512)if 5 in attention_positions:self.attention_layers['block5'] = TransformerEncoderLayer(512)self.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, num_classes),)def _apply_attention(self, x, block_name):"""應用自注意力機制"""batch_size, channels, height, width = x.size()# 將特征圖重塑為序列形式 [batch_size, seq_len, embed_dim]x_reshaped = x.view(batch_size, channels, -1).transpose(1, 2)# 應用注意力attended = self.attention_layers[block_name](x_reshaped)# 恢復原始形狀attended = attended.transpose(1, 2).view(batch_size, channels, height, width)return attendeddef forward(self, x):features = []# 逐層處理特征for i, layer in enumerate(self.features):x = layer(x)# 在特定卷積塊后應用注意力if i == 14 and 3 in self.attention_positions:  # 第三卷積塊結束x = self._apply_attention(x, 'block3')elif i == 21 and 4 in self.attention_positions:  # 第四卷積塊結束x = self._apply_attention(x, 'block4')elif i == 28 and 5 in self.attention_positions:  # 第五卷積塊結束x = self._apply_attention(x, 'block5')x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x# 創建帶注意力的VGG模型
def vgg16_with_attention(num_classes=1000, attention_positions=[3, 4]):model = VGG16WithAttention(num_classes=num_classes, attention_positions=attention_positions)return model# 示例使用
if __name__ == "__main__":model = vgg16_with_attention(num_classes=1000, attention_positions=[3, 4, 5])# 測試前向傳播dummy_input = torch.randn(2, 3, 224, 224)output = model(dummy_input)print(f"Output shape: {output.shape}")print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/96298.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/96298.shtml
英文地址,請注明出處:http://en.pswp.cn/web/96298.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

ES6 面試題及詳細答案 80題 (33-40)-- Symbol與集合數據結構

《前后端面試題》專欄集合了前后端各個知識模塊的面試題,包括html,javascript,css,vue,react,java,Openlayers,leaflet,cesium,mapboxGL,threejs&…

PG-210-HI 山洪預警系統呼叫端:筑牢山區應急預警 “安全防線”

在山洪災害多發的山區,及時、準確的預警信息傳遞是保障群眾生命財產安全的關鍵。由 PG-210-HI 型號構成的山洪預警系統呼叫端主機,憑借其全面的功能、先進的特性與可靠的性能,成為連接管理員與群眾的重要應急樞紐,為山區構建起一道…

研學旅游產品設計實訓室:賦能產品落地,培養實用人才

1. 研學旅游產品設計實訓室的定位與功能 研學旅游產品設計實訓室是專門為學生提供研學課程與產品開發、模擬設計、項目推演、成果展示等實踐活動的教學空間。該實訓室應支持以下功能: 研學主題設計與目標制定; 課程內容與學習方法的選擇與整合&#xf…

4215kg輕型載貨汽車變速器設計cad+設計說明書

第一章 前言 3 1.1 變速器的發展環繞現狀 3 1.2 本次設計目的和意義 4 第二章 傳動機構布置方案分析及設計 5 2.1 傳動機構結構分析與類型選擇 5 2.2變速器主傳動方案的選擇 5 2.3 倒檔傳動方案 6 2..4 變速器零、部件結構方案設計 6 2.4.1 齒輪形式 …

9月10日

TCP客戶端代碼#include<myhead.h> #define SER_IP "192.168.108.179" //服務器&#xff49;&#xff50;地址 #define SER_PORT 8888 //服務器端口號 #define CLI_IP "192.168.108.239" //客戶端&#xff49;&#xff50;地址 …

案例開發 - 日程管理 - 第七期

項目改造&#xff0c;進入 demo-schedule 項目中&#xff0c;下載 pinia 依賴在 main.js 中開啟 piniaimport { createApp } from vue import App from ./App.vue import router from ./router/router.js import {createPinia} from pinialet pinia createPinia() const app …

infinityfree 網頁連接內網穿透 localtunnel會換 還是用frp成功了

模型庫首頁 魔搭社區 fatedier/frp: A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet. 我嘗試用本機ipv6&#xff0c;失敗了 配置文件 - ChmlFrp 香港2才能用 只支持https CNAME解析 | 怊貓科技 | 文檔 How to create …

批量更新數據:Mybatis update foreach 和 update case when 寫法及比較

在平常的開發工作中&#xff0c;我們經常需要批量更新數據&#xff0c;業務需要每次批量更新幾千條數據&#xff0c;采用 update foreach 寫法的時候&#xff0c;接口響應 10s 左右&#xff0c;優化后&#xff0c;采用 update ... case when 寫法&#xff0c;接口響應 2s 左右。…

Java基礎篇04:數組、二維數組

1 數組 數組是一個數據容器&#xff0c;可用來存儲一批同類型的數據。 1.1 數組的定義方式 靜態初始化 數據類型[][] 數組名 {元素1&#xff0c;元素2&#xff0c;元素3}; string[][] name {"wfs","jsc","qf"} 動態初始化 數據類型[][] 數組名…

unity開發類似個人網站空間

可以用 Unity 開發 “個人網站空間” 類工具&#xff0c;但需要結合其技術特性和適用場景來判斷是否合適。以下從技術可行性、優勢、局限性、適用場景四個方面具體分析&#xff1a;一、技術可行性Unity 本質是游戲引擎&#xff0c;但具備開發 “桌面應用” 和 “交互內容” 的能…

SDK游戲盾如何實現動態加密

SDK游戲盾的動態加密體系通過??密鑰動態管理、多層加密架構、協議混淆、AI自適應調整及設備綁定??等多重機制協同作用&#xff0c;實現對游戲數據全生命周期的動態保護&#xff0c;有效抵御中間人攻擊、協議破解、重放攻擊等威脅。以下從核心技術與實現邏輯展開詳細說明&am…

TensorFlow平臺介紹

什么是 TensorFlow&#xff1f; TensorFlow 是一個由 Google Brain 團隊 開發并維護的 開源、端到端機器學習平臺。它的核心是一個強大的數值計算庫&#xff0c;特別擅長于使用數據流圖來表達復雜的計算任務&#xff0c;尤其適合大規模機器學習和深度學習模型的構建、訓練和部署…

TENGJUN防水TYPE-C連接器:立貼結構與IPX7防護的精密融合

在戶外電子、智能家居、車載設備等對連接可靠性與空間適配性要求嚴苛的場景中&#xff0c;連接器不僅是信號與電力傳輸的“橋梁”&#xff0c;更需抵御潮濕、粉塵等復雜環境的侵蝕。TENGJUN防水TYPE-C連接器以“雙排立貼”為核心設計&#xff0c;融合鋅合金底座、精準尺寸控制與…

Spring Boot + Vue 項目中使用 Redis 分布式鎖案例

加鎖使用命令&#xff1a;set lock_key unique_value NX PX 1000NX:等同于SETNX &#xff0c;只有鍵不存在時才能設置成功PX&#xff1a;設置鍵的過期時間為10秒unique_value&#xff1a;一個必須是唯一的隨機值&#xff08;UUID&#xff09;&#xff0c;通常由客戶端生成…

微信小程序攜帶token跳轉h5, h5再返回微信小程序

需求: 在微信小程序內跳轉到h5, 瀏覽完后點擊返回按鈕再返回到微信小程序中 微信小程序跳轉h5: 微信小程序跳轉h5,這個還是比較簡單的, 但要注意細節 一、微信小程序代碼 1.新建跳轉h5頁面, 新建文件夾,新建page即可 2.使用web-view標簽 wxml頁面 js頁面 到此為止, 小程序…

【機器學習】通過tensorflow實現貓狗識別的深度學習進階之路

【機器學習】通過tensorflow實現貓狗識別的深度學習進階之路 簡介 貓狗識別作為計算機視覺領域的經典入門任務&#xff0c;不僅能幫助我們掌握深度學習的核心流程&#xff0c;更能直觀體會到不同優化策略對模型性能的影響。本文將從 “從零搭建簡單 CNN” 出發&#xff0c;逐步…

異步處理(前端面試)

Promise 1&#xff1a;使用promise原因 了解回調地獄【什么是回調地獄】 1&#xff1a;回調地獄是異步獲取結果后&#xff0c;為下一個異步函數提供參數&#xff0c;層層回調嵌入回調 2&#xff1a;導致回調層次很深&#xff0c;代碼維護特別困難 3&#xff1a;在沒有ES6時&…

3種XSS攻擊簡單案例

1、接收cookie端攻擊機上用python寫個接收web程序flask from flask import Flask, request, Responseapp Flask(__name__)app.route(/) def save_cookie():cookie request.args.get(cookie, )if cookie:with open(/root/cookies.txt, a) as f:f.write(f"{cookie}\n"…

Docker 部署生產環境可用的 MySQL 主從架構

簡介跨云服務器一主一從&#xff0c;可以自己按照邏輯配置多個從服務器 假設主服務器ip: 192.168.0.4 從服務器ip&#xff1a;192.168.0.5 系統 CentOS7.9 &#xff08;停止維護了&#xff0c;建議大家用 Ubuntu 之類的&#xff0c;我這個沒辦法&#xff0c;前人在云服務器上…

DeepResearch(上)

概述 OpenAI首先推出Deep Research Agent&#xff0c;深度研究智能體&#xff0c;簡稱DRA。 通過自主編排多步驟網絡探索、定向檢索和高階綜合&#xff0c;可將大量在線信息轉換為分析師級別的、引用豐富的報告&#xff0c;將數小時的手動桌面研究壓縮為幾分鐘。 作為新一代…