[pytorch、學習] - 5.9 含并行連結的網絡(GoogLeNet)

參考

5.9 含并行連結的網絡(GoogLeNet)

在2014年的ImageNet圖像識別挑戰賽中,一個名叫GoogLeNet的網絡結構大放異彩。它雖然在名字上向LeNet致敬,但在網絡結構上已經很難看到LeNet的影子。GoogLeNet吸收了NiN中網絡串聯網絡的思想,并在此基礎上做了很大改進。在隨后的幾年里,研究人員對GoogLeNet進行了數次改進,本節將介紹這個模型系列的第一個版本。

5.9.1 Inception塊

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-fn8Tzn2Z-1594258539068)(attachment:image.png)]

Inception塊中可以自定義的超參數是每個層的輸出通道,我們以此來控制模型的復雜度

import time
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..")
import d2lzh_pytorch as d2ldevice = torch.device("cuda" if torch.cuda.is_available() else 'cpu')class Inception(nn.Module):def __init__(self, in_c, c1, c2, c3, c4):super(Inception, self).__init__()# 路線1: 單 1 x 1 卷積層self.p1_1 = nn.Conv2d(in_c, c1, kernel_size =1)# 路線2: 1 x 1 卷積層后接3 x  3卷積層self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size = 1)self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size = 3, padding = 1)# 路線3: 1 x 1卷積層后接5 x 5卷積層self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size = 1)self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size = 5, padding = 2)# 路線4: 3 x 3 最大池化層后接1 x 1卷積層self.p4_1 = nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)self.p4_2 = nn.Conv2d(in_c, c4, kernel_size = 1)def forward(self, x):p1 = F.relu(self.p1_1(x))p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))p4 = F.relu(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim = 1)  # 在通道維上連結輸出

5.9.2 GoogLeNet模型

GoogLeNet跟VGG一樣,在主體卷積部分中使用5個模塊(block),每個模塊之間使用步幅為2的3x3最大池化層來減少輸出寬高。第一模塊使用一個64通道的7x7卷積層。

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size =7, stride =2, padding =3),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
)

第二模塊使用2個卷積層: 首先通過64通道的 1x1 卷積層,然后是將通道增大3倍的 3x3 卷積層。它對應Inception塊中的第二條線路。

b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size = 1),nn.Conv2d(64, 192, kernel_size = 3, padding =1),nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
)

第三模塊串聯2個完整的Inception塊。第一個Inception塊的輸出通道數為 64 + 128 + 32 + 32 = 256,其中4條路線的輸出通道數比例為 64: 128 : 32: 32 = 2: 4: 1: 1。其中第二、第三條路線先分別將輸入通道數減少至96/192 = 1/2 和 16/192 = 1/ 12后,再接上第二層卷積層。第二個Inception塊輸出通道數增至 128 +192 +96 + 64 = 480,每條線路的輸出通道數之比為 128: 192: 96: 64 = 4: 6 : 3 : 2。其中第二、第三條路線先分別將輸入通道數減少至 128/ 256 = 1/2 和 32/ 256 = 1/8。

b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),Inception(256, 128, (128, 192),(32, 96), 64),nn.MaxPool2d(kernel_size = 3, stride =2, padding =1)
)

第四模塊更加復雜。它串聯了5個Inception塊,其輸出通道數分別是192 + 208 + 48 + 64 = 512、160 + 224 + 64 + 64 = 512、128 + 256 + 64 + 64 = 512、 112 + 288 + 64 + 64 = 528 和 256 + 320 + 128 + 128 = 832。這些線路的通道數分配和第三模塊中類似,首先含3x3卷積層的第二條線路輸出最多通道,其次是僅含1x1卷積層的第一條線路,之后是含5x5卷積層的第3條線路和含3x3最大池化層的第4條線路。其中第二、第三條線路都會按比例減小通道數。這些比例在各個Inception塊中都略有不同。

b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),Inception(512, 160, (112, 224), (24, 64), 64),Inception(512, 128, (128, 256), (24, 64), 64),Inception(512, 112, (144, 288), (32, 64), 64),Inception(528, 256, (160, 320), (32, 128), 128),nn.MaxPool2d(kernel_size = 3, stride=2, padding =1)
)

第五模塊有輸出通道數為 256 + 320 + 128 + 128 = 832 和 834 + 384 + 128 + 128 = 1024的兩個Inception塊。其中每條線路的通道數的分配思路和第三、第四模塊中的一致,只是在具體數值上有所不同。需要注意的是,第五模塊的后面緊跟輸出層,該模塊同NiN一樣使用全局平均池化層來將每個通道的寬和高變成1。最后我們將輸出變成二維數組后接上一個輸出個數為標簽類別數的全連接層。

b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48, 128), 128),d2l.GlobalAvgPool2d()
)net = nn.Sequential(b1, b2, b3, b4, b5, d2l.FlattenLayer(), nn.Linear(1024, 10))print(net)

在這里插入圖片描述
在這里插入圖片描述
在這里插入圖片描述
將輸入的高和寬從224降到96來簡化計算,下面演示各個模塊之間的輸出的形狀變化

net = nn.Sequential(b1, b2, b3, b4, b5, d2l.FlattenLayer(), nn.Linear(1024, 10))
X = torch.rand(1, 1, 96, 96)
for blk in net.children():X = blk(X)print("output shape: ", X.shape)

在這里插入圖片描述

5.9.3 獲取數據和訓練模型

batch_size = 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr = lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/250117.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/250117.shtml
英文地址,請注明出處:http://en.pswp.cn/news/250117.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

mybits注解詳解

一、mybatis 簡單注解 關鍵注解詞 : Insert : 插入sql , 和xml insert sql語法完全一樣 Select : 查詢sql, 和xml select sql語法完全一樣 Update : 更新sql, 和xml update sql語法完全一樣 Delete : 刪除sql, 和xml d…

使用python裝飾器計算函數運行時間的實例

使用python裝飾器計算函數運行時間的實例 裝飾器在python里面有很重要的作用, 如果能夠熟練使用,將會大大的提高工作效率 今天就來見識一下 python 裝飾器,到底是怎么工作的。 本文主要是利用python裝飾器計算函數運行時間 一些需要精確的計算…

SQLServer用存儲過程實現插入更新數據

實現 1)有同樣的數據,直接返回(返回值:0)。 2)有主鍵同樣。可是數據不同的數據。進行更新處理(返回值:2); 3)沒有數據,進行插入數據處…

[pytorch、學習] - 9.1 圖像增廣

參考 9.1 圖像增廣 在5.6節(深度卷積神經網絡)里我們提過,大規模數據集是成功應用神經網絡的前提。圖像增廣(image augmentation)技術通過對訓練圖像做一系列隨機改變,來產生相似但又不相同的訓練樣本,從而擴大訓練數據集的規模。圖像增廣的另一種解釋是,隨機改變訓練樣本可以…

mysql綠色版安裝

導讀:MySQL是一款關系型數據庫產品,官網給出了兩種安裝包格式:MSI和ZIP。MSI格式是圖形界面安裝方式,基本只需下一步即可,這篇文章主要介紹ZIP格式的安裝過程。ZIP Archive版是免安裝的。只要解壓就行了。 一、首先下…

在微信瀏覽器字體被調大導致頁面錯亂的解決辦法

iOS的解決方案是覆蓋掉微信的樣式: body { /* IOS禁止微信調整字體大小 */-webkit-text-size-adjust: 100% !important; } 安卓的解決方案是通過 WeixinJSBridge 對象將網頁的字體大小設置為默認大小,并且重寫設置字體大小的方法,讓用戶不能在…

[pytorch、學習] - 9.2 微調

參考 9.2 微調 在前面得一些章節中,我們介紹了如何在只有6萬張圖像的Fashion-MNIST訓練數據集上訓練模型。我們還描述了學術界當下使用最廣泛規模圖像數據集ImageNet,它有超過1000萬的圖像和1000類的物體。然而,我們平常接觸到數據集的規模通常在這兩者之間。 假設我們想從圖…

Springboot默認加載application.yml原理

Springboot默認加載application.yml原理以及擴展 SpringApplication.run(…)默認會加載classpath下的application.yml或application.properties配置文件。公司要求搭建的框架默認加載一套默認的配置文件demo.properties,讓開發人員實現“零”配置開發,但…

java 集合(Set接口)

Set接口:無序集合,不允許有重復值,允許有null值 存入與取出的順序有可能不一致 HashSet:具有set集合的基本特性,不允許重復值,允許null值 底層實現是哈希表結構 初始容量為16 保存自定義對象時,保證數據的唯…

關于mac機抓包的幾點基礎知識

1. 我使用的抓包工具為WireShark,以下操作按我當前的版本(Version 2.6.1)做的,以前的版本或者以后的版本可能有稍微的區別。 2. 將mac設置為熱點:打開系統偏好設置,點擊共享: 然后點擊WIFI選項,設置WIFI名…

SpringBoot啟動如何加載application.yml配置文件

一、前言 在spring時代配置文件的加載都是通過web.xml配置加載的(Servlet3.0之前)&#xff0c;可能配置方式有所不同&#xff0c;但是大多數都是通過指定路徑的文件名的形式去告訴spring該加載哪個文件&#xff1b; <context-param><param-name>contextConfigLocat…

[github] - git使用小結(分支拉取、版本回退)

1. 首次(fork項目之后) $ git clone [master] $ git branch -a $ git checkout -b [自己的分支名] [遠程倉庫的分支名]克隆的是主干網絡 2. 再次拉取代碼 $ git pull [master下選擇分支名] [分支名] $ git push origin HEAD:[分支名]拉取首先得進入主倉(不是自己的遠程倉)然后…

MYSQL 查看最大連接數和修改最大連接數

MySQL查看最大連接數和修改最大連接數 1、查看最大連接數show variables like %max_connections%;2、修改最大連接數set GLOBAL max_connections 200; 以下的文章主要是向大家介紹的是MySQL最大連接數的修改&#xff0c;我們大家都知道MySQL最大連接數的默認值是100, 這個數值…

阿里云服務器端口開放對外訪問權限

登陸阿里云管理控制臺 點擊自己的實例 點擊安全組配置 點擊配置規則 點擊添加安全組規則 配置出入放心&#xff0c;和開放的端口號&#xff0c;以及那些網段可以訪問&#xff0c;這里設置所有網段都可以訪問 轉自&#xff1a;https://jingyan.baidu.com/article/95c9d20d624d1e…

PageHelper工作原理

數據分頁功能是我們軟件系統中必備的功能&#xff0c;在持久層使用mybatis的情況下&#xff0c;pageHelper來實現后臺分頁則是我們常用的一個選擇&#xff0c;所以本文專門類介紹下。 PageHelper原理 相關依賴 <dependency><groupId>org.mybatis</groupId>&…

10-多寫一個@Autowired導致程序崩了

再是javaweb實驗六中&#xff0c;是讓我們改代碼&#xff0c;讓它跑起來&#xff0c;結果我少注釋了一個&#xff0c;導致一直報錯&#xff0c;檢查許久沒有找到&#xff0c;最后通過代碼替換逐步查找&#xff0c;才發現問題。 轉載于:https://www.cnblogs.com/zhumengdexiaoba…

Java class不分32位和64位

1、32位JDK編譯的java class在32位系統和64位系統下都可以運行&#xff0c;64位系統兼容32位程序&#xff0c;可以理解。2、無論是Linux還是Windows平臺下的JDK編譯的java class在Linux、Windows平臺下通用&#xff0c;Java跨平臺特性。3、64位JDK編譯的java class在32位的系統…

包裝對象

原文地址&#xff1a;https://wangdoc.com/javascript/ 定義 對象是JavaScript語言最主要的數據類型&#xff0c;三種原始類型的值--數值、字符串、布爾值--在一定條件下&#xff0c;也會自動轉為對象&#xff0c;也就是原始類型的包裝對象。所謂包裝對象&#xff0c;就是分別與…

[C++] 轉義序列

參考 C Primer(第5版)P36 名稱轉義序列換行符\n橫向制表符\t報警(響鈴)符\a縱向制表符\v退格符\b雙引號"反斜杠\問號?單引號’回車符\r進紙符\f

vue使用(二)

本節目標&#xff1a; 1.數據路徑的三種方式 2.{{}}和v-html的區別 1.綁定圖片的路徑 方法一&#xff1a;直接寫路徑 <img src"http://pic.baike.soso.com/p/20140109/20140109142534-188809525.jpg"> 方法二&#xff1a;在data中寫路徑&#xff0c;在…