PyTorch從零開始實現ResNet

文章目錄

    • 代碼實現
    • 參考

代碼實現

本文實現 ResNet原論文 Deep Residual Learning for Image Recognition 中的50層,101層和152層殘差連接。
在這里插入圖片描述
代碼中使用基礎殘差塊這個概念,這里的基礎殘差塊指的是上圖中紅色矩形圈出的內容:從上到下分別使用3, 4, 6, 3個基礎殘差塊,每個基礎殘差塊由三個卷積層組成,核大小分別為1x1, 3x3, 1x1 。

殘差連接的結構

在這里插入圖片描述

復現代碼如下:

import torch
import torch.nn as nn# 基礎殘差塊,后面ResNet要多次重復使用該塊
class block(nn.Module):def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):super(block, self).__init__()self.expansion = 4  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)self.relu = nn.ReLU()self.identity_downsample = identity_downsampledef forward(self, x):identity = xx = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.conv3(x)x = self.bn3(x)# x 和 identity形狀一致,才能相加if self.identity_downsample is not None:identity = self.identity_downsample(identity)x += identityx = self.relu(x)return xclass ResNet(nn.Module):def __init__(self, block, layers, image_channels, num_classes):super(ResNet, self).__init__()# 初始化的層self.in_channels = 64self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# ResNet layersself.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512*4, num_classes)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)x = self.fc(x)return x# 核心函數:調用block基礎殘差塊,構造ResNet的每一層def _make_layer(self, block, num_residual_blocks, out_channels, stride):identity_downsample = Nonelayers = []# 修改形狀,使得殘差連接可以相加:x + identityif stride != 1 or self.in_channels != out_channels * 4:identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1,stride=stride),                                               nn.BatchNorm2d(out_channels*4))layers.append(block(self.in_channels, out_channels, identity_downsample, stride))self.in_channels = out_channels * 4for i in range(num_residual_blocks - 1):layers.append(block(self.in_channels, out_channels)) # 256 -> 64, 64*4(256) againreturn nn.Sequential(*layers)# 構造ResNet50層:默認圖像通道3,分類類別為1000
def resnet50(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)# 構造ResNet101層  
def resnet101(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)# 構造ResNet152層  
def resnet152(img_channels=3, num_classes=1000):return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)# 測試輸出y的形狀是否滿足1000類
def test():net = resnet152()x = torch.randn(2, 3, 224, 224)y = net(x)print(y.shape) # [2, 1000]test()

參考

[1] Deep Residual Learning for Image Recognition
[2] https://www.youtube.com/watch?v=DkNIBBBvcPs&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=19

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/43540.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/43540.shtml
英文地址,請注明出處:http://en.pswp.cn/news/43540.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

感覺和身邊其他人有差距怎么辦?

雖然清楚知識需要靠時間沉淀,但在看到自己做不出來的題別人會做,自己寫不出的代碼別人會寫時還是會感到焦慮怎么辦? 你是否也因為自身跟周圍人的差距而產生過迷茫,這份迷茫如今是被你克服了還是仍舊讓你感到困擾? 下…

LabVIEW開發最小化5G系統測試平臺

LabVIEW開發最小化5G系統測試平臺 由于具有大量存儲能力和數據的應用程序的智能手機的激增,當前一代產品被迫提高其吞吐效率。正交頻分復用由于其卓越的品質,如單抽頭均衡和具有成本效益的實施,現在被廣泛用作物理層技術。這些好處是以嚴格的…

ElasticSearch索引庫、文檔、RestClient操作

文章目錄 一、索引庫1、mapping屬性2、索引庫的crud 二、文檔的crud三、RestClient 一、索引庫 es中的索引是指相同類型的文檔集合,即mysql中表的概念 映射:索引中文檔字段的約束,比如名稱、類型 1、mapping屬性 mapping映射是對索引庫中文…

Elasticsearch在部署時,對Linux的設置有哪些優化方法?

部署Elasticsearch時,可以通過優化Linux系統的設置來提升性能和穩定性。以下是一些常見的優化方法: 1.文件描述符限制 Elasticsearch需要大量的文件描述符來處理數據和連接,所以確保調整系統的文件描述符限制。可以通過修改 /etc/security/…

Docker-compose搭建Git私服

1. 新建個專用的目錄,然后在里面新建個docker-compose.yml文件: (gitlab-ce是社區版,當然還有ee,是商業版) version: 3.6 …

es自定義分詞器支持數字字母分詞,中文分詞器jieba支持添加禁用詞和擴展詞典

自定義分析器,分詞器 PUT http://xxx.xxx.xxx.xxx:9200/test_index/ {"settings": {"analysis": {"analyzer": {"char_test_analyzer": {"tokenizer": "char_test_tokenizer","filter": [&…

公網遠程連接Redis數據庫詳解

文章目錄 1. Linux(centos8)安裝redis數據庫2. 配置redis數據庫3. 內網穿透3.1 安裝cpolar內網穿透3.2 創建隧道映射本地端口 4. 配置固定TCP端口地址4.1 保留一個固定tcp地址4.2 配置固定TCP地址4.3 使用固定的tcp地址連接 前言 潔潔的個人主頁 我就問你有沒有發揮&#xff0…

ssh免密登陸報錯ERROR: @ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!

問題描述: 在日常的運維中需要做ssh的免密登陸有提示如下的報錯內容: [rootpaas-harbor01 cce-v5.2.3]# ssh-copy-id 192.45.66.14 /usr/bin/ssh-copy-id: INFO: Source of key(s) to be installed: "/root/.ssh/id_rsa.pub" /usr/bin/ssh-c…

通訊錄實現【C語言】

目錄 前言 一、整體邏輯分析 二、實現步驟 1、創建菜單和多次操作問題 2、創建通訊錄 3、初始化通訊錄 4、添加聯系人 5、顯示聯系人 6、刪除指定聯系人 ?7、查找指定聯系人 8、修改聯系人信息 9、排序聯系人信息 三、全部源碼 前言 我們上期已經詳細的介紹了自定…

Java SpringBoot Vue ERP系統

系統介紹 該ERP系統基于SpringBoot框架和SaaS模式,支持多租戶,專注進銷存財務生產功能。主要模塊有零售管理、采購管理、銷售管理、倉庫管理、財務管理、報表查詢、系統管理等。支持預付款、收入支出、倉庫調撥、組裝拆卸、訂單等特色功能。擁有商品庫存…

ubuntu設置共享文件夾成功后卻不顯示找不到(已解決)

1.首先輸下面命令查看是否真的設置成功共享文件夾 vmware-hgfsclient如果確實已經設置過共享文件夾將輸出window下共享文件夾名字 2.確認自己已設置共享文件夾后輸入下面的命令 //如果之前沒有命令包則先執行sudo apt-get install open-vm-tools sudo vmhgfs-fuse .host:/ /mn…

十六、Spring Cloud Sleuth 分布式請求鏈路追蹤

目錄 一、概述1、為什么出出現這個技術?需要解決哪些問題2、是什么?3、解決 二、搭建鏈路監控步驟1、下載運行zipkin2、服務提供者3、服務調用者4、測試 一、概述 1、為什么出出現這個技術?需要解決哪些問題 2、是什么? 官網&am…

spss---如何使用信度分析以及案例分析

信度分析 問卷調查法是教育研究中廣泛采用的一種調查方法,根據調查目的設計的調查問卷是問卷調查法獲取信息的工具,其質量高低對調查結果的真實性、適用性等具有決定性的作用。 為了保證問卷具有較高的可靠性和有效性,在形成正式問卷之 前&…

CLion:最好用的c/c++編寫工具(最詳細安裝教程)

目錄 一.前言介紹 1.下載安裝 1.1右上角點擊下載 1.2選擇自己操作系統,然后點擊下載 1.3選擇next 1.4 更改路徑 1.5D盤最好 1.6 按照我的選擇配置環境 1.7install安裝 1.8 安裝完成 2、mingw64安裝 2.1下載資源壓縮包 2.2mingw64放入到合適的位置,…

Redis五大基本數據類型及其使用場景

文章目錄 **一 什么是NoSQL?****二 redis是什么?****三 redis五大基本類型**1 String(字符串)**應用場景** 2 List(列表)**應用場景** 3 Set(集合)4 sorted set(有序集合…

高級藝術二維碼制作教程

最近不少關于二維碼制作的,而且都是付費。大概就是一個好看的二維碼,掃描后跳轉網址。本篇文章使用Python來實現,這么簡單花啥錢呢?學會,拿去賣便宜點吧。 文章目錄 高級二維碼制作環境安裝普通二維碼藝術二維碼動態 …

【LVS】2、部署LVS-DR群集

LVS-DR數據包的流向分析 1.客戶端發送請求到負載均衡器,請求的數據報文到達內核空間; 2.負載均衡服務器和正式服務器在同一個網絡中,數據通過二層數據鏈路層來傳輸; 3.內核空間判斷數據包的目標IP是本機VIP,此時IP虛…

批量將Excel中的第二列內容從拼音轉換為漢字

要批量將Excel中的第二列內容從拼音轉換為漢字,您可以使用Python的openpyxl庫來實現。下面是一個示例代碼,演示如何讀取Excel文件并將第二列內容進行拼音轉漢字: from openpyxl import load_workbook from xpinyin import Pinyin # 打開Exce…

Android kotlin系列講解(入門篇)使用Intent在Activity之間穿梭

<<返回總目錄 上一篇:Android kotlin系列講解(入門篇)Activity的理解與基本用法 文章目錄 1、使用顯式Intent2、使用隱式Intent3、更多隱式Intent的用法4、向下一個Activity傳遞數據5、返回數據給上一個Activity1、使用顯式Intent 你應該已經對創建Activity的流程比較…

SASS 學習筆記

SASS 學習筆記 總共會寫兩個練手項目&#xff0c;成品在 https://goldenaarcher.com/scss-study 可以看到&#xff0c;代碼在 https://github.com/GoldenaArcher/scss-study。 什么是 SASS SASS 是 CSS 預處理&#xff0c;它提供了變量&#xff08;雖然現在 CSS 也提供了&am…