UNet 改進(5):結合SE模塊提升圖像分割性能

U-Net是醫學圖像分割領域最成功的架構之一,其對稱的編碼器-解碼器結構和跳躍連接使其能夠有效捕捉多尺度特征。本文將解析一個改進版的U-Net實現,該版本通過引入Squeeze-and-Excitation(SE)模塊進一步提升了模型性能。

一、架構概覽

這個改進的U-Net保持了經典U-Net的核心結構,但在每個卷積塊后添加了SE模塊,主要包含以下幾個關鍵組件:

  1. SE注意力模塊:增強重要通道的特征響應

  2. 雙卷積塊:基礎特征提取單元

  3. 編碼器-解碼器結構:逐步下采樣和上采樣

  4. 跳躍連接:結合低層和高層特征

二、核心組件詳解

1. SE注意力模塊 (SELayer)

class SELayer(nn.Module):def __init__(self, in_channels, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(in_channels // reduction, in_channels, bias=False),nn.Sigmoid())

SE模塊通過以下步驟工作:

  1. 使用全局平均池化將空間信息壓縮為一個通道描述符

  2. 通過兩個全連接層學習通道間的依賴關系

  3. 使用Sigmoid激活生成通道權重

  4. 將權重應用于原始特征圖

這種機制讓模型能夠自適應地強調重要特征通道,抑制不重要的通道。

2. 改進的雙卷積塊 (DoubleConv)

class DoubleConv(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),SELayer(out_channels)  # 添加 SE 模塊)

每個雙卷積塊包含:

  • 兩個3×3卷積層,保持空間分辨率(padding=1)

  • 每個卷積后接批量歸一化和ReLU激活

  • 最后添加SE模塊進行通道注意力加權

3. 完整的改進U-Net (ImprovedUNet)

編碼器部分通過最大池化逐步下采樣,解碼器部分通過轉置卷積上采樣,并結合跳躍連接:

class ImprovedUNet(nn.Module):def __init__(self, n_channels, n_classes):# 初始化各層...def forward(self, x):# 編碼過程x1 = self.inc(x)       # 初始卷積x2 = self.down1(x1)    # 下采樣1x3 = self.down2(x2)    # 下采樣2x4 = self.down3(x3)    # 下采樣3x5 = self.down4(x4)    # 下采樣4# 解碼過程x = self.up1(x5)x = self.double_conv_up1(torch.cat([x, F.interpolate(x4, size=x.shape[2:])], dim=1))# ...類似處理其他上采樣層return self.outc(x)

三、創新點與優勢

  1. SE模塊集成:在每個雙卷積塊后添加SE模塊,使模型能夠自適應地重新校準通道特征響應

  2. 改進的特征融合:使用雙線性插值調整跳躍連接特征圖尺寸,確保精確對齊

  3. 參數效率:通過factor參數控制解碼器通道數,平衡模型容量和計算成本

四、性能分析

這個改進版U-Net相比原始U-Net有以下潛在優勢:

  • 更好的特征選擇能力,通過SE模塊突出重要特征

  • 更穩定的訓練,得益于批量歸一化的廣泛使用

  • 更精確的邊界預測,得益于改進的特征融合方式

五、使用示例

# 創建模型實例
model = ImprovedUNet(n_channels=3, n_classes=1)# 隨機輸入測試
input_tensor = torch.randn(2, 3, 256, 256)  # 2張256x256的RGB圖像
output = model(input_tensor)  # 輸出形狀為[2, 1, 256, 256]

六、完整代碼

import torch
import torch.nn as nn
import torch.nn.functional as F# SE 模塊
class SELayer(nn.Module):def __init__(self, in_channels, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(in_channels // reduction, in_channels, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)# 改進的卷積塊
class DoubleConv(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),SELayer(out_channels)  # 添加 SE 模塊)def forward(self, x):return self.double_conv(x)# 改進的 U-Net 模型
class ImprovedUNet(nn.Module):def __init__(self, n_channels, n_classes):super().__init__()self.n_channels = n_channelsself.n_classes = n_classesself.inc = DoubleConv(n_channels, 64)self.down1 = nn.Sequential(nn.MaxPool2d(2),DoubleConv(64, 128))self.down2 = nn.Sequential(nn.MaxPool2d(2),DoubleConv(128, 256))self.down3 = nn.Sequential(nn.MaxPool2d(2),DoubleConv(256, 512))factor = 2self.down4 = nn.Sequential(nn.MaxPool2d(2),DoubleConv(512, 1024 // factor))self.up1 = nn.Sequential(nn.ConvTranspose2d(1024 // factor, 512 // factor, kernel_size=2, stride=2))self.double_conv_up1 = DoubleConv(512 // factor + 512, 512 // factor)self.up2 = nn.Sequential(nn.ConvTranspose2d(512 // factor, 256 // factor, kernel_size=2, stride=2))self.double_conv_up2 = DoubleConv(256 // factor + 256, 256 // factor)self.up3 = nn.Sequential(nn.ConvTranspose2d(256 // factor, 128 // factor, kernel_size=2, stride=2))self.double_conv_up3 = DoubleConv(128 // factor + 128, 128 // factor)self.up4 = nn.Sequential(nn.ConvTranspose2d(128 // factor, 64, kernel_size=2, stride=2))self.double_conv_up4 = DoubleConv(64 + 64, 64)self.outc = nn.Conv2d(64, n_classes, kernel_size=1)def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5)x = self.double_conv_up1(torch.cat([x, F.interpolate(x4, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))x = self.up2(x)x = self.double_conv_up2(torch.cat([x, F.interpolate(x3, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))x = self.up3(x)x = self.double_conv_up3(torch.cat([x, F.interpolate(x2, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))x = self.up4(x)x = self.double_conv_up4(torch.cat([x, F.interpolate(x1, size=x.shape[2:], mode='bilinear', align_corners=True)], dim=1))logits = self.outc(x)return logits# 創建改進的 U-Net 模型實例
model = ImprovedUNet(n_channels=3, n_classes=1)
print(model)# 生成一個隨機輸入
input_tensor = torch.randn(2, 3, 256, 256)# 前向傳播
output = model(input_tensor)
print(output.shape)

七、適用場景

這種改進的U-Net特別適合以下任務:

  • 醫學圖像分割(CT/MRI)

  • 遙感圖像解析

  • 任何需要精確邊界預測的密集預測任務

八、總結

通過在U-Net中集成SE模塊,我們獲得了能夠自適應關注重要特征的改進架構。這種設計在不顯著增加計算成本的情況下,提高了模型的特征選擇能力,使其在各種圖像分割任務中表現更加出色。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/76068.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/76068.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/76068.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

機器人擰螺絲緊固裝配(Robot screw fastening assembly)

機器人擰螺絲緊固裝配技術正以其高精度、高效率和高靈活性,重塑著傳統制造業的生產范式。這項融合了機械臂定位、扭矩控制、視覺引導與數據分析的自動化解決方案,不僅將工人從重復性高強度勞動中解放出來,更通過實時數據反饋與精準執行&#…

圖像處理中的 Gaussina Blur 和 SIFT 算法

Gaussina Blur 高斯模糊 高斯模糊的數學定義 高斯模糊是通過 高斯核(Gaussian Kernel) 對圖像進行卷積操作實現的. 二維高斯函數定義為 G ( x , y , σ ) 1 2 π σ 2 e ? x 2 y 2 2 σ 2 G(x, y, \sigma) \frac{1}{2\pi \sigma^2} e^{-\frac{x^2 y^2}{2\sigma^2}} G(x…

在Unity中實現《幽靈行者》風格的跑酷動作

基礎設置 角色控制器選擇: 使用Character Controller組件或Rigidbody Capsule Collider 推薦使用Character Controller以獲得更精確的運動控制 輸入系統: 使用Unity的新輸入系統(Input System Package)處理玩家輸入 滑鏟實現 public class Slide…

青蛙吃蟲--dp

1.dp數組有關元素--路長和次數 2.遞推公式 3.遍歷順序--最終影響的是路長&#xff0c;在外面 其次次數遍歷&#xff0c;即這次路長所有情況都更新 最后&#xff0c;遍歷次數自然就要遍歷跳長 4.max時時更新 dp版本 #include<bits/stdc.h> using namespace std; #def…

Tiktok 關鍵字 視頻及評論信息爬蟲(2) [2025.04.07]

&#x1f64b;?♀?Tiktok APP的基于關鍵字檢索的視頻及評論信息爬蟲共分為兩期&#xff0c;希望對大家有所幫助。 第一期&#xff1a;基于關鍵字檢索的視頻信息爬取 第二期見下文。 1.Node.js環境配置 首先配置 JavaScript 運行環境&#xff08;如 Node.js&#xff09;&…

Matlab繪圖—‘‘錯誤使用 plot輸入參數的數目不足‘‘

原因1&#xff1a; ?? 文件列名不是合法變量名 在excel中數據列名稱為Sample:float,將:刪除就解決了

Kotlin問題匯總

Kotlin問題匯總 真機安裝調試 查看真機的Android版本&#xff0c;將build.gradle文件中的minSdk改為手機的Android版本&#xff0c;點Sync Now更新設置 apk安裝失敗 在gradle.properties全局配置中設置android.injected.testOnlyfalse Unresolved reference: 在activity_…

基于VMware的Cent OS Stream 8安裝與配置及遠程連接軟件的介紹

1.VMware Workstation 簡介&#xff1a; VMware Workstation&#xff08;中文名“威睿工作站”&#xff09;是一款功能強大的桌面虛擬計算機軟件&#xff0c;提供用戶可在單一的桌面上同時運行不同的操作系統&#xff0c;和進行開發、測試 、部署新的應用程序的最佳解決方案。…

Go語言從零構建SQL數據庫(4)-解析器

SQL解析器&#xff1a;數據庫的"翻譯官"圖解與代碼詳解 圖解SQL解析過程 SQL解析器就像是人類語言與計算機之間的翻譯官&#xff0c;將我們書寫的SQL語句轉換成數據庫能夠理解和執行的結構。 #mermaid-svg-f9gAqHutDLL4McGy {font-family:"trebuchet ms"…

十道海量數據處理面試題與十個方法總結

一、十道海量數據處理面試題 ??1、海量日志數據&#xff0c;提取出某日訪問百度次數最多的那個IP。(分治思想 哈希表) 首先&#xff0c;從日志中提取出所有訪問百度的IP地址&#xff0c;將它們逐個寫入一個大文件中&#xff0c;便于后續處理。 考慮到IP地址是32位的&#…

SolidWorks2025三維計算機輔助設計(3D CAD)軟件超詳細圖文安裝教程(2025最新版保姆級教程)

目錄 前言 一、SolidWorks下載 二、SolidWorks安裝 三、啟動SolidWorks 前言 SolidWorks 是一款由法國達索系統&#xff08;Dassault Systmes&#xff09;公司開發的三維計算機輔助設計&#xff08;3D CAD&#xff09;軟件&#xff0c;廣泛用于機械設計、工程仿真和產品開…

IntelliJ IDEA 2020~2024 創建SpringBoot項目編輯報錯: 程序包org.springframework.boot不存在

目錄 前奏解決結尾 前奏 哈&#xff01;今天在處理我的SpringBoot項目時&#xff0c;突然遇到了一些讓人摸不著頭腦的錯誤提示&#xff1a; java: 程序包org.junit不存在 java: 程序包org.junit.runner不存在 java: 程序包org.springframework.boot.test.context不存在 java:…

CPU 壓力測試命令大全

CPU 壓力測試命令大全 以下是 Linux/Unix 系統下常用的 CPU 壓力測試命令和工具&#xff0c;可用于測試 CPU 性能、穩定性和散熱能力。 1. 基本壓力測試命令 1.1 使用 yes 命令 yes > /dev/null & # 啟動一個無限循環進程 yes > /dev/null & # 啟動第二個進…

#SVA語法滴水穿石# (003)關于 sequence 和 property 的區別和聯系

在 SystemVerilog Assertions (SVA) 中,sequence 和 property 是兩個核心概念,它們既有區別又緊密相關。對于初學者,可能不需要過多理解;但是要想寫出復雜精美的斷言,深刻理解兩者十分重要。今天,我們匯總和學習一下該知識點。 1. 區別 特性sequenceproperty定義描述一系…

WordPress浮動廣告插件+飄動效果客服插件

源碼介紹 WordPress浮動廣告插件飄動效果客服插件 將源碼上傳到wordpress的插件根目錄下&#xff0c;解壓&#xff0c;然后后臺啟用即可 截圖 源碼免費獲取 WordPress浮動廣告插件飄動效果客服插件

虛幻基礎:藍圖基礎知識

文章目錄 組件藍圖創建時&#xff0c;優先創建組件&#xff0c;如c一樣。 UI控件控件不會自動創建&#xff0c;而是在藍圖創建函數中手動創建。 函數內使用S序列接退出&#xff0c;并不會等所有執行完再退出&#xff0c;而是一個執行完后直接退出 組件 藍圖創建時&#xff0c;…

《AI大模型應知應會100篇》加餐篇:LlamaIndex 與 LangChain 的無縫集成

加餐篇&#xff1a;LlamaIndex 與 LangChain 的無縫集成 問題背景&#xff1a;在實際應用中&#xff0c;開發者常常需要結合多個框架的優勢。例如&#xff0c;使用 LangChain 管理復雜的業務邏輯鏈&#xff0c;同時利用 LlamaIndex 的高效索引和檢索能力構建知識庫。本文在基于…

深度學習項目--分組卷積與ResNext網絡實驗探究(pytorch復現)

&#x1f368; 本文為&#x1f517;365天深度學習訓練營 中的學習記錄博客&#x1f356; 原作者&#xff1a;K同學啊 前言 ResNext是分組卷積的開始之作&#xff0c;這里本文將學習ResNext網絡&#xff1b;本文復現了ResNext50神經網絡&#xff0c;并用其進行了猴痘病分類實驗…

從代碼學習深度學習 - RNN PyTorch版

文章目錄 前言一、數據預處理二、輔助訓練工具函數三、繪圖工具函數四、模型定義五、模型訓練與預測六、實例化模型并訓練訓練結果可視化總結前言 循環神經網絡(RNN)是深度學習中處理序列數據的重要模型,尤其在自然語言處理和時間序列分析中有著廣泛應用。本篇博客將通過一…

JS DOM節點增刪改查

增加節點 通過document.createNode()函數創建對象 // 創建節點 const div document.createElement(div) // 追加節點 document.body.appendChild(div) 克隆節點 刪除節點