Pytorch模型復現筆記-STN(空間注意力Transformer網絡)講解+架構搭建(可直接copy運行)+ MNIST數據集視角調整實驗

Spatial Transformer Networks

本文了講述STN的基本架構,空間幾何注意力模塊的基本原理,冒煙測試以及STN在MNIST數據集用于模型自動調整圖片視角的實驗,如果大家有不懂或者發現了錯誤的地方,歡迎討論。

  • 中文名:空間Transformer網絡
  • 論文鏈接:Arxiv

我更傾向于叫它為Spatio Geometry Transformer Network, 因為它的注意力同時包括了是旋轉,平移,仿射等多種幾何變換,而不是單純地裁剪以注意空間里面的重點。

目錄

  • Spatial Transformer Networks
    • 模型簡介
    • 提出背景
    • 設計思路
    • 達到效果及優勢
    • 對后續模型的影響
    • 網絡結構
    • Pytorch模型實現+MNIST數據集視角調整實驗
      • 準備庫、數據集、數據加載器
      • 定義網絡結構
      • 定義訓練/測試過程
      • 可視化效果

模型簡介

  • 作者:Ghassen HAMROUNI
  • 發布年:2015
  • 為什么這么叫:因為它用使用了空間幾何注意力
  • 主要成就:第一個使用空間幾何注意力的卷積神經網絡

STN是對任何空間變化的推廣,其允許神經網絡學習如何對輸入圖像進行 空間變換,以增強模型的 幾何不變性。例如,其可以對感興趣的區域進行裁剪,或者縮放和矯正圖像的方向。這個機制對CNN很有用,因為其對旋轉、縮放、甚至更一般的放射變換并非不變。

在這里插入圖片描述

提出背景

之前有哪些相同目的的模型/方法?

在STN出現之前,主要是用純粹的CNN來對圖像進行特征圖的提取。但是由于CNN對于圖像的幾何變換的魯棒性不強,因此研究人員設計了幾種方法來改良CNN對于幾何變換的魯棒性:

  1. 數據增強(Data Augmentation):這是最簡單也最常用的方法。通過對訓練集中的圖片進行隨機旋轉、縮放、平移或裁剪等操作,來增加數據的多樣性。這使得模型在訓練時能接觸到更多不同幾何形態的樣本,從而提高其泛化能力
  2. 多尺度或者多角度訓練:直接在多個尺度或角度下對同一圖像進行訓練,迫使網絡學習對這些變換不變的特征。
  3. 使用手工設計的特征(如SIFT、ORB等): 在傳統的計算機視覺任務中,人們會使用這些對幾何變換具有一定不變性的特征來處理問題。例如,SIFT特征描述子在一定程度上對尺度和旋轉是不變的。

之前的模型/方法有什么不足?

  1. 數據增強的局限性: 數據增強雖然有效,但它是一種“靜態”的、預先定義好的方法。它并不能讓網絡自適應地學習應該對哪些變換進行處理。換句話說,模型是在被動地接受經過變換的數據,而不是主動地去“尋找”并“矯正”那些重要的區域。
  2. 計算效率低下:多尺度或多角度訓練會顯著增加計算量和內存消耗,因為需要為每個變換后的版本都進行一次前向傳播。
  3. 無法處理復雜的、非預期的變換: 數據增強通常只覆蓋簡單的變換,如旋轉和平移。對于更復雜的、特定于任務的“扭曲”或“不規則”變換,效果不佳。

設計思路

這個模型針對不足提出了什么改進方案?解決了什么問題?有什么人類直覺在里面?

鑒于之前方法的缺陷,作者從人類直覺的角度進行了思考,他認為,人類之所以能夠從一個偏移,旋轉,或者不同視角下的圖片中還原原本的圖片(比如把一不同視角下的5,無論仰視還是俯視都可以看出來),是因為我們腦袋中 “自帶一個用來進行動態幾何變換的機制” (可以理解為自帶一個“自適應”的幾何變換矩陣),我們能根據注意力自動調整這個矩陣的參數,把圖像校正到我們大腦中最容易理解和識別的“理想狀態”。
在這里插入圖片描述
以下是幾何變換矩陣的原理:
在這里插入圖片描述
如果對變換矩陣施加如下約束,那么這個矩陣則可以對原圖進行旋轉,平移,以及仿射變換。
在這里插入圖片描述

但是實際的代碼實現中并不會對其施加以上約束,因為模型可能通過學習學到更加高級的幾何變換,而不僅僅局限于以上三種變換。

達到效果及優勢

在這里插入圖片描述

  • 性能提升: 相比于其他沒有使用STN的模型(如 Cimpoi '15, Simon '15 等),使用了STN的CNN模型在CUB-200-2011鳥類分類數據集上的準確率有了明顯的提升。在高分辨率的圖片上性能提升更加顯著。
  • 可解釋的空間變換注意力: 這是STN最直觀也最令人興奮的優勢。圖表右側的圖片展示了2xST-CNN和4xST-CNN模型中STN模塊學習到的空間變換。論文作者在圖中特別指出,對于2xST-CNN模型,一個STN模塊(紅色框)學習定位和放大鳥的頭部,而另一個STN模塊(綠色框)則學習定位和放大鳥的身體。也就是每個STN模塊都注意到了不同的,但是對鳥的分類至關重要東西!
  • 即插即用的模塊:STN最大的優勢之一就是他能非常容易地插入任何現有的CNN,并且只需要很小的修改!

對后續模型的影響

  1. 開創“可學習的變換”思想
    STN首次將空間變換的能力作為可學習的模塊集成到神經網絡中。它證明了網絡可以自己決定如何對輸入數據進行幾何變換,而不是依賴于預先設定好的規則(如數據增強)。 這種思想被廣泛應用于各種需要處理非剛性、非線性變換的任務中。例如,在醫學圖像處理中,STN的思想被用來進行圖像配準(Image Registration),自動對齊不同時間或不同設備拍攝的病灶圖像。

  2. 空間注意力機制的先驅
    盡管STN的關注點是“幾何變換”,但它通過定位并變換最關鍵的區域,實際上實現了一種形式的注意力。它讓網絡將“注意力”集中在最關鍵的像素或特征上,并將其“擺正”以方便后續處理。它的成功啟發了后續的注意力機制(Attention Mechanism)研究。雖然STN是“空間注意力”,更嚴謹一定來說叫“空間幾何注意力”,但它證明了讓網絡“有選擇地”關注輸入中最重要的部分是提高性能的有效手段。這為后來更廣泛的通道注意力(Channel Attention)自注意力(Self-Attention)以及Transformer模型的興起奠定了基礎。

  3. “即插即用”模塊化設計的典范
    STN模塊可以輕松地插入到任何CNN架構中,這極大地降低了其應用門檻,并展示了模塊化設計在深度學習中的巨大潛力。* 這種設計理念被廣泛采納。如今,許多深度學習模型都由各種可插拔的模塊組成,比如SENet中的“通道注意力”模塊、ECA-Net中的“高效通道注意力”模塊等等。這些模塊都遵循了STN的“即插即用”設計思想,讓研究人員可以更容易地進行模型改進和創新。

總而言之,STN的貢獻遠不止于提高了一點點準確率。它引入的 “可學習變換”“空間注意力”“模塊化設計” 等核心思想,深刻地影響了后續的計算機視覺和深度學習研究,成為連接傳統CNN和現代Transformer模型的一個重要橋梁。

網絡結構

在這里插入圖片描述
Spatial Transformer 模組可以分解成如下三個關鍵組成部分:

  1. 定位網絡(localisation net):其中包括兩個全連接層,第一個層負責提取圖片中的基礎幾何信息,第二個層負責根據基礎幾何信息回歸出幾何變換矩陣
  2. 網格生成器(grid generator): 負責根據生成的變換矩陣生成變換網格,本質上是定義了圖像的變換
  3. 采樣器(Sampler):利用定義好的grid對原圖片進行變換

Pytorch模型實現+MNIST數據集視角調整實驗

準備庫、數據集、數據加載器

首先我們把庫和數據集,以及數據加載器準備好:


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as npplt.ion()   # interactive modefrom six.moves import urllib
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Training dataset
train_loader = torch.utils.data.DataLoader(datasets.MNIST(root='.', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST(root='.', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])), batch_size=64, shuffle=True, num_workers=4)

定義網絡結構

接下來就是搭建我們帶有一個簡單Spatio Transformer 模塊的網絡了:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)# 修復:等號'='而不是'-'self.fc2 = nn.Linear(50, 10)self.flatten = nn.Flatten()# STN模塊組成-特征提取localization子網絡, 用來先提取能反應圖像幾何信息的高級特征# 修復:括號'()'而不是')'self.localization = nn.Sequential(nn.Conv2d(1, 8, kernel_size=7),nn.MaxPool2d(2, stride=2),nn.ReLU(True),nn.Conv2d(8, 10, kernel_size=5),nn.MaxPool2d(2, stride=2),nn.ReLU(True))# STN模塊組成-全連接定位網絡-回歸仿射矩陣self.fc_loc = nn.Sequential(nn.Linear(10 * 3 * 3, 32), nn.ReLU(True),# 直接把幾何變換矩陣全連接出來nn.Linear(32, 3 * 2))# 網絡權重初始化,保證一開始的變換矩陣從什么都不做的單位幀self.fc_loc[2].weight.data.zero_() # 定位網絡直接用0初始化權重self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) """[1, 0, 0][0, 1, 0]"""def stn(self, x):xs = self.localization(x)xs = self.flatten(xs)# 修復:xs_theta = self.fc_loc(xs)# 重塑成矩陣形狀theta = theta.view(-1, 2, 3)# 網格生成部分,對圖像引用幾何變換矩陣產生新圖像grid = F.affine_grid(theta, x.size(), align_corners=True)x = F.grid_sample(x, grid, align_corners=True)return xdef forward(self, x):x = self.stn(x)# 傳給分類層# Perform the usual forward passx = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = self.flatten(x)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)# 冒煙測試
# 修復:缺失 device 的定義
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
# 修復:缺失 .to(device)
input = torch.randn(1, 1, 28, 28).to(device)
output = model(input)
# 修復:輸出的維度是 [1, 10],所以應該使用 output.shape
print(output.shape) # 輸出torch.Size([1, 10])

定義訓練/測試過程

optimizer = optim.SGD(model.parameters(), lr=0.01)def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx % 500 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))
#
# A simple test procedure to measure the STN performances on MNIST.
#def test():with torch.no_grad():model.eval()test_loss = 0correct = 0for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)# sum up batch losstest_loss += F.nll_loss(output, target, size_average=False).item()# get the index of the max log-probabilitypred = output.max(1, keepdim=True)[1]correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

可視化效果

我們這里直接可視化原圖像在經過空間注意力模塊調整之后會變成啥樣:

def convert_image_np(inp):"""Convert a Tensor to numpy image."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)return inp# We want to visualize the output of the spatial transformers layer
# after the training, we visualize a batch of input images and
# the corresponding transformed batch using STN.def visualize_stn():with torch.no_grad():# Get a batch of training datadata = next(iter(test_loader))[0].to(device)input_tensor = data.cpu()transformed_input_tensor = model.stn(data).cpu()in_grid = convert_image_np(torchvision.utils.make_grid(input_tensor))out_grid = convert_image_np(torchvision.utils.make_grid(transformed_input_tensor))# Plot the results side-by-sidef, axarr = plt.subplots(1, 2)axarr[0].imshow(in_grid)axarr[0].set_title('Dataset Images')axarr[1].imshow(out_grid)axarr[1].set_title('Transformed Images')for epoch in range(1, 20 + 1):train(epoch)test()# Visualize the STN transformation on some input batch
visualize_stn()plt.ioff()
plt.show()

在這里插入圖片描述
最終的準確率在epoch20之后達到了99%,并且可以看到STN模塊確實對原圖進行了幾何變換,把圖像校正到了對模型,甚至對人類都更容易理解和識別的“理想狀態”。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/93788.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/93788.shtml
英文地址,請注明出處:http://en.pswp.cn/web/93788.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【LeetCode】16. 最接近的三數之和

文章目錄16. 最接近的三數之和題目描述示例 1:示例 2:提示:解題思路算法分析問題本質分析排序雙指針法詳解雙指針移動策略搜索過程可視化各種解法對比算法流程圖邊界情況處理時間復雜度分析空間復雜度分析關鍵優化點實際應用場景測試用例設計…

微信小程序實現藍牙開啟自動播放BGM

下面是一個完整的微信小程序實現方案,當藍牙設備連接時自動播放背景音樂(BGM)。實現思路監聽藍牙設備連接狀態當檢測到藍牙設備連接時,自動播放音樂當藍牙斷開時,停止音樂播放處理相關權限和用戶交互完整代碼實現1. 項目結構text/pages/index…

XML 序列化與操作詳解筆記

一、XML 基礎概念XML&#xff08;eXtensible Markup Language&#xff0c;可擴展標記語言&#xff09;是一種用于存儲和傳輸數據的標記語言&#xff0c;由 W3C 制定&#xff0c;具有以下特點&#xff1a;可擴展性&#xff1a;允許自定義標記&#xff08;如<Student>、<…

第八十四章:實戰篇:圖 → 視頻:基于 AnimateDiff 的視頻合成鏈路——讓你的圖片“活”起來,瞬間擁有“電影感”!

AI圖生視頻前言&#xff1a;從“剎那永恒”到“動態大片”——AnimateDiff&#xff0c;讓圖片“活”起來&#xff01;第一章&#xff1a;痛點直擊——靜態圖像到視頻&#xff0c;不是“幻燈片”那么簡單&#xff01;第二章&#xff1a;探秘“時間魔法”&#xff1a;AnimateDiff…

2025深大計算機考研復試經驗貼(已上岸)

如果你在初試出分前看到此貼 我建議&#xff1a; 準備機試和簡歷&#xff0c;即使你不估分&#xff1a;因為如果要準備春招的話&#xff0c;也總要刷題和做簡歷的。盡早估分&#xff0c;查一下往年的復試線&#xff0c;如果有望進復試&#xff0c;可盡早開始準備。 Preface …

用Pygame開發桌面小游戲:從入門到發布

一、引言 Pygame是一個基于Python的跨平臺游戲開發庫,它提供了簡單易用的圖形、聲音和輸入處理功能,非常適合新手入門游戲開發。本文將以"經典游戲合集"項目為例,帶你一步步了解如何使用Pygame開發、打包和發布自己的桌面小游戲。 二、開發環境搭建 安裝Python:…

CSS backdrop-filter:給元素背景添加模糊與色調的高級濾鏡

在現代網頁設計中&#xff0c;半透明元素搭配背景模糊效果已成為流行趨勢 —— 從毛玻璃導航欄、模態框遮罩&#xff0c;到卡片懸停效果&#xff0c;這種設計能讓界面更具層次感和高級感。實現這一效果的核心 CSS 屬性&#xff0c;正是backdrop-filter。它能對元素背后的內容&a…

檢索增強生成(RAG) 緩存增強生成(CAG) 生成中檢索(RICHES) 知識庫增強語言模型(KBLAM)

以下是當前主流的四大知識增強技術方案對比&#xff0c;涵蓋核心原理、適用場景及最新發展趨勢&#xff0c;為開發者提供清晰的技術選型參考&#xff1a; &#x1f50d; 一、RAG&#xff08;檢索增強生成&#xff09;?? 核心原理?&#xff1a; 動態檢索外部知識庫&#xff0…

LLM(大語言模型)的工作原理 圖文講解

目錄 1. 條件概率&#xff1a;上下文預測的基礎 2. LLM 是如何“看著上下文寫出下一個詞”的&#xff1f; 補充說明&#xff08;重要&#xff09; &#x1f4cc; Step 1: 輸入處理 &#x1f4cc; Step 2: 概率計算 &#x1f4cc; Step 3: 決策選擇 &#x1f914; 一個有…

Python netifaces 庫詳解:跨平臺網絡接口與 IP 地址管理

一、前言 在現代網絡編程中&#xff0c;獲取本機的網絡接口信息和 IP 配置是非常常見的需求。 例如&#xff1a; 開發一個需要選擇合適網卡的 網絡服務&#xff1b;在多網卡環境下實現 流量路由與控制&#xff1b;在系統診斷工具中展示 IP/MAC 地址、子網掩碼、默認網關&#x…

HTML應用指南:利用POST請求獲取上海黃金交易所金價數據

上海黃金交易所&#xff08;SGE&#xff09;作為中國唯一經國務院批準、專門從事黃金等貴金屬交易的國家級市場平臺&#xff0c;自成立以來始終秉持“公開、公平、公正”的原則&#xff0c;致力于構建規范、高效、透明的貴金屬交易市場體系。交易所通過完善的交易機制、嚴格的風…

C++常見面試題-1.C++基礎

一、C 基礎 1.1 語言特性與區別C 與 C 的主要區別是什么&#xff1f;C 為何被稱為 “帶類的 C”&#xff1f; 主要區別&#xff1a;C 引入了面向對象編程&#xff08;OOP&#xff09;特性&#xff08;類、繼承、多態等&#xff09;&#xff0c;而 C 是過程式編程語言&#xff1…

Tomcat里catalina.sh詳解

在 Tomcat 中&#xff0c;catalina.sh&#xff08;Linux/macOS&#xff09;或 catalina.bat&#xff08;Windows&#xff09;是 核心的啟動和關閉腳本&#xff0c;用于控制 Tomcat 服務器的運行。它是 Tomcat 的“主控腳本”&#xff0c;負責設置環境變量、啟動/關閉 JVM 進程&…

STM32之MCU和GPIO

一、單片機MCU 1.1 單片機和嵌入式 嵌入式系統 以計算機為核心&#xff0c;tips&#xff1a;計算機【處理單元&#xff0c;內存 硬盤】 可以控制的外部設備&#xff0c;傳感器&#xff0c;電機&#xff0c;繼電器 嵌入式開發 數據源--> 處理器(CPU MCU MPU) --> 執行器 …

22_基于深度學習的桃子成熟度檢測系統(yolo11、yolov8、yolov5+UI界面+Python項目源碼+模型+標注好的數據集)

目錄 項目介紹&#x1f3af; 功能展示&#x1f31f; 一、環境安裝&#x1f386; 環境配置說明&#x1f4d8; 安裝指南說明&#x1f3a5; 環境安裝教學視頻 &#x1f31f; 二、數據集介紹&#x1f31f; 三、系統環境&#xff08;框架/依賴庫&#xff09;說明&#x1f9f1; 系統環…

數據結構:二叉樹oj練習

在講今天的題目之前&#xff0c;我們還需要講一下二叉樹的以下特點&#xff1a; 對任意一顆二叉樹&#xff0c;如果度為0的節點個數是n0&#xff0c;度為2的節點個數是n2&#xff0c;則有n0n21. 證明&#xff1a;二叉樹總的節點個數是n&#xff0c;那么有nn0n1n2 二叉樹的度為…

RabbitMQ高級特性——TTL、死信隊列、延遲隊列、事務、消息分發

目錄 一、TTL 1.1設置消息的TTL 1.2設置隊列的TTL 1.3兩者之間的區別 二、死信隊列 2.1死信的概念 2.2死信產生的條件&#xff1a; 2.3死信隊列的實現 死信隊列的工作原理 2.4常??試題 三、延遲隊列 3.1概念 3.2應用場景 3.3RabbitMQ 實現延遲隊列的核心原理 1…

神經網絡設計中關于BN歸一化(Normalization)的討論

在神經網絡的結構中&#xff0c;我們常常可以看見歸一化&#xff08;Normalization&#xff09;如BN的出現&#xff0c;無論是模型的backbone或者是neck的設計都與它有著重大的關系。 因此引發了我對它的思考&#xff0c;接下來我將從 是什么&#xff08;知識領域&#xff0c;誕…

MacOS 安全機制與“文件已損壞”排查完整指南

1. 背景說明macOS 為了保護系統安全&#xff0c;內置了多個安全機制&#xff1a;機制作用是否影響第三方 AppSIP (System Integrity Protection)保護系統關鍵文件/目錄不被篡改高風險 App/驅動可能受限Gatekeeper限制未簽名/未認證 App 運行阻止“未知開發者” App文件隔離屬性…

package.json文件中的devDependencies和dependencies對象有什么區別?

前端項目的package.json文件中&#xff0c;dependencies和devDependencies對象都用于指定項目所依賴的軟件包&#xff0c;但它們在項目的開發和生產環境中的使用有所不同。1.dependencies&#xff1a;dependencies是指定項目在生產環境中運行所需要的依賴項。這些依賴項通常包括…