昇思25天學習打卡營第13天 | ShuffleNet圖像分類

在這里插入圖片描述

ShuffleNet網絡介紹

ShuffleNetV1是曠視科技提出的一種計算高效的CNN模型,和MobileNet, SqueezeNet等一樣主要應用在移動端,所以模型的設計目標就是利用有限的計算資源來達到最好的模型精度。ShuffleNetV1的設計核心是引入了兩種操作:Pointwise Group Convolution和Channel Shuffle,這在保持精度的同時大大降低了模型的計算量。因此,ShuffleNetV1和MobileNet類似,都是通過設計更高效的網絡結構來實現模型的壓縮和加速。

如下圖所示,ShuffleNet在保持不低的準確率的前提下,將參數量幾乎降低到了最小,因此其運算速度較快,單位參數量對模型準確率的貢獻非常高。

圖片來源:Bianco S, Cadene R, Celona L, et al. Benchmark analysis of representative deep neural network architectures[J]. IEEE access, 2018, 6: 64270-64277.

模型架構

ShuffleNet最顯著的特點在于對不同通道進行重排來解決Group Convolution帶來的弊端。通過對ResNet的Bottleneck單元進行改進,在較小的計算量的情況下達到了較高的準確率。

Pointwise Group Convolution

Group Convolution(分組卷積)原理如下圖所示,相比于普通的卷積操作,分組卷積的情況下,每一組的卷積核大小為in_channels/gkk,一共有g組,所有組共有(in_channels/gkk)*out_channels個參數,是正常卷積參數的1/g。分組卷積中,每個卷積核只處理輸入特征圖的一部分通道,其優點在于參數量會有所降低,但輸出通道數仍等于卷積核的數量。

在這里插入圖片描述
Depthwise Convolution(深度可分離卷積)將組數g分為和輸入通道相等的in_channels,然后對每一個in_channels做卷積操作,每個卷積核只處理一個通道,記卷積核大小為1kk,則卷積核參數量為:in_channelskk,得到的feature maps通道數與輸入通道數相等;

Pointwise Group Convolution(逐點分組卷積)在分組卷積的基礎上,令每一組的卷積核大小為 1×1
,卷積核參數量為(in_channels/g11)*out_channels。

%%capture captured_output
# 實驗環境已經預裝了mindspore==2.2.14,如需更換mindspore版本,可更改下面mindspore的版本號
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
from mindspore import nn
import mindspore.ops as ops
from mindspore import Tensorclass GroupConv(nn.Cell):def __init__(self, in_channels, out_channels, kernel_size,stride, pad_mode="pad", pad=0, groups=1, has_bias=False):super(GroupConv, self).__init__()self.groups = groupsself.convs = nn.CellList()for _ in range(groups):self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,kernel_size=kernel_size, stride=stride, has_bias=has_bias,padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))def construct(self, x):features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)outputs = ()for i in range(self.groups):outputs = outputs + (self.convs[i](features[i].astype("float32")),)out = ops.cat(outputs, axis=1)return out

Channel Shuffle

Group Convolution的弊端在于不同組別的通道無法進行信息交流,堆積GConv層后一個問題是不同組之間的特征圖是不通信的,這就好像分成了g個互不相干的道路,每一個人各走各的,這可能會降低網絡的特征提取能力。這也是Xception,MobileNet等網絡采用密集的1x1卷積(Dense Pointwise Convolution)的原因。

為了解決不同組別通道“近親繁殖”的問題,ShuffleNet優化了大量密集的1x1卷積(在使用的情況下計算量占用率達到了驚人的93.4%),引入Channel Shuffle機制(通道重排)。這項操作直觀上表現為將不同分組通道均勻分散重組,使網絡在下一層能處理不同組別通道的信息。

在這里插入圖片描述
如下圖所示,對于g組,每組有n個通道的特征圖,首先reshape成g行n列的矩陣,再將矩陣轉置成n行g列,最后進行flatten操作,得到新的排列。這些操作都是可微分可導的且計算簡單,在解決了信息交互的同時符合了ShuffleNet輕量級網絡設計的輕量特征。

在這里插入圖片描述

ShuffleNet模塊

如下圖所示,ShuffleNet對ResNet中的Bottleneck結構進行由(a)到(b), ?的更改:

  1. 將開始和最后的 1×1
    卷積模塊(降維、升維)改成Point Wise Group Convolution;

  2. 為了進行不同通道的信息交流,再降維之后進行Channel Shuffle;

  3. 降采樣模塊中, 3×3 Depth Wise Convolution的步長設置為2,長寬降為原來的一般,因此shortcut中采用步長為2的 3×3
    平均池化,并把相加改成拼接。

在這里插入圖片描述

class ShuffleV1Block(nn.Cell):def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):super(ShuffleV1Block, self).__init__()self.stride = stridepad = ksize // 2self.group = groupif stride == 2:outputs = oup - inpelse:outputs = oupself.relu = nn.ReLU()branch_main_1 = [GroupConv(in_channels=inp, out_channels=mid_channels,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=1 if first_group else group),nn.BatchNorm2d(mid_channels),nn.ReLU(),]branch_main_2 = [nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,pad_mode='pad', padding=pad, group=mid_channels,weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(mid_channels),GroupConv(in_channels=mid_channels, out_channels=outputs,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=group),nn.BatchNorm2d(outputs),]self.branch_main_1 = nn.SequentialCell(branch_main_1)self.branch_main_2 = nn.SequentialCell(branch_main_2)if stride == 2:self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')def construct(self, old_x):left = old_xright = old_xout = old_xright = self.branch_main_1(right)if self.group > 1:right = self.channel_shuffle(right)right = self.branch_main_2(right)if self.stride == 1:out = self.relu(left + right)elif self.stride == 2:left = self.branch_proj(left)out = ops.cat((left, right), 1)out = self.relu(out)return outdef channel_shuffle(self, x):batchsize, num_channels, height, width = ops.shape(x)group_channels = num_channels // self.groupx = ops.reshape(x, (batchsize, group_channels, self.group, height, width))x = ops.transpose(x, (0, 2, 1, 3, 4))x = ops.reshape(x, (batchsize, num_channels, height, width))return x

構建ShuffleNet網絡

ShuffleNet網絡結構如下圖所示,以輸入圖像 224×224
,組數3(g = 3)為例,首先通過數量24,卷積核大小為 3×3
,stride為2的卷積層,輸出特征圖大小為 112×112
,channel為24;然后通過stride為2的最大池化層,輸出特征圖大小為 56×56
,channel數不變;再堆疊3個ShuffleNet模塊(Stage2, Stage3, Stage4),三個模塊分別重復4次、8次、4次,其中每個模塊開始先經過一次下采樣模塊(上圖?),使特征圖長寬減半,channel翻倍(Stage2的下采樣模塊除外,將channel數從24變為240);隨后經過全局平均池化,輸出大小為 1×1×960
,再經過全連接層和softmax,得到分類概率。

在這里插入圖片描述

class ShuffleNetV1(nn.Cell):def __init__(self, n_class=1000, model_size='2.0x', group=3):super(ShuffleNetV1, self).__init__()print('model size is ', model_size)self.stage_repeats = [4, 8, 4]self.model_size = model_sizeif group == 3:if model_size == '0.5x':self.stage_out_channels = [-1, 12, 120, 240, 480]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 240, 480, 960]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 360, 720, 1440]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 480, 960, 1920]else:raise NotImplementedErrorelif group == 8:if model_size == '0.5x':self.stage_out_channels = [-1, 16, 192, 384, 768]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 384, 768, 1536]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 576, 1152, 2304]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 768, 1536, 3072]else:raise NotImplementedErrorinput_channel = self.stage_out_channels[1]self.first_conv = nn.SequentialCell(nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(input_channel),nn.ReLU(),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')features = []for idxstage in range(len(self.stage_repeats)):numrepeat = self.stage_repeats[idxstage]output_channel = self.stage_out_channels[idxstage + 2]for i in range(numrepeat):stride = 2 if i == 0 else 1first_group = idxstage == 0 and i == 0features.append(ShuffleV1Block(input_channel, output_channel,group=group, first_group=first_group,mid_channels=output_channel // 4, ksize=3, stride=stride))input_channel = output_channelself.features = nn.SequentialCell(features)self.globalpool = nn.AvgPool2d(7)self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)def construct(self, x):x = self.first_conv(x)x = self.maxpool(x)x = self.features(x)x = self.globalpool(x)x = ops.reshape(x, (-1, self.stage_out_channels[-1]))x = self.classifier(x)return x

模型訓練和評估

采用CIFAR-10數據集對ShuffleNet進行預訓練。

訓練集準備與加載
采用CIFAR-10數據集對ShuffleNet進行預訓練。CIFAR-10共有60000張32*32的彩色圖像,均勻地分為10個類別,其中50000張圖片作為訓練集,10000圖片作為測試集。如下示例使用mindspore.dataset.Cifar10Dataset接口下載并加載CIFAR-10的訓練集。目前僅支持二進制版本(CIFAR-10 binary version)。

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracydef train():mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")net = ShuffleNetV1(model_size="2.0x", n_class=10)loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)min_lr = 0.0005base_lr = 0.05lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,base_lr,batches_per_epoch*250,batches_per_epoch,decay_epoch=250)lr = Tensor(lr_scheduler[-1])optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)callback = [TimeMonitor(), LossMonitor()]save_ckpt_path = "./"config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)callback += [ckpt_callback]print("============== Starting Training ==============")start_time = time.time()# 由于時間原因,epoch = 5,可根據需求進行調整model.train(5, dataset, callbacks=callback)use_time = time.time() - start_timehour = str(int(use_time // 60 // 60))minute = str(int(use_time // 60 % 60))second = str(int(use_time % 60))print("total time:" + hour + "h " + minute + "m " + second + "s")print("============== Train Success ==============")if __name__ == '__main__':train()

在這里插入圖片描述

學習心得

通過本次學習,我不僅掌握了ShuffleNetV1的網絡結構和實現方法,還深入理解了分組卷積和通道重排在提高模型效率中的作用。未來,我希望能夠進一步探索ShuffleNetV2以及其他高效模型的設計與應用,并嘗試將其應用于更多復雜的數據集和任務中。同時,我還計劃研究模型壓縮和加速的其他技術,如模型剪枝和量化,以進一步提升模型的應用性能。
在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/42699.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/42699.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/42699.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

ExcelVBA運用Excel的【條件格式】(二)

ExcelVBA運用Excel的【條件格式】(二)前面知識點回顧1. 訪問 FormatConditions 集合 Range.FormatConditions2. 添加條件格式 FormatConditions.Add 方法語法表達式。添加 (類型、 運算符、 Expression1、 Expression2)3. 修改或刪除條件格式4. …

如何在Spring Boot中實現動態多語言支持

如何在Spring Boot中實現動態多語言支持 大家好,我是免費搭建查券返利機器人省錢賺傭金就用微賺淘客系統3.0的小編,也是冬天不穿秋褲,天冷也要風度的程序猿! 一、引言 隨著全球化市場的發展,多語言支持已經成為現代…

密碼技術中分組模式解析

目錄 1. 概述 2. ECB模式 2.1 概述 2.2 ECB模式的加密 2.3 ECB模式的解密 2.4 優點 2.5 缺點 3. CBC模式【推薦】 3.1 概述 3.2 CBC模式的加密 3.3 CBC模式的解密 3.4 優點 3.5 缺點 4. CFB模式 4.1 概述 4.2 CFB模式的加密 4.3 CFB模式的解密 4.4 優點 4.…

智慧地產視覺監控系統開源了,系統采用多種優化技術,提高系統的響應速度和資源利用率

智慧地產視覺監控平臺是一款功能強大且簡單易用的實時算法視頻監控系統。它的愿景是最底層打通各大芯片廠商相互間的壁壘,省去繁瑣重復的適配流程,實現芯片、算法、應用的全流程組合,從而大大減少企業級應用約95%的開發成本。用戶只需在界面上…

Python打開Excel文檔并讀取數據

Python 版本 目前 Python 3 版本為主流版本,這里測試的版本是:Python 3.10.5。 常用庫說明 Python 操作 Excel 的常用庫有:xlrd、xlwt、xlutils、openpyxl、pandas。這里主要說明下 Excel 文檔 .xls 格式和 .xlsx 格式的文檔打開和讀取。 …

Drools開源業務規則引擎(二)- Drools規則語言(DRL)

文章目錄 1.DRL文件的組成:2.package3.import4.function5.query6.declare7.global8.rule8.1.規則屬性8.2.LHS8.2.1.語法格式8.2.2.運算符優先級8.2.3.特殊的運算符1.matches, not matches2.contains, not contains3.memberOf, not memberOf4.in, notin5.soundslike6…

Powershell 獲取電腦保存的所有wifi密碼

一. 知識點 netsh wlan show profiles 用于顯示計算機上已保存的無線網絡配置文件 Measure-Object 用于統計數量 [PSCustomObject]{ } 用于創建Powershell對象 [math]::Round 四舍五入 Write-Progress 顯示進度條 二. 代碼 只能獲取中文Windows操作系統的wifi密碼如果想獲取…

護網在即,助力安服仔漏洞掃描~

整合了個漏掃系統,安服仔必備~ 使用場景 網前布防,漏洞掃描,資產梳理 使用方法: 啟動虛擬機后運行命令: ./StartSystemScript.sh 輸入密碼attack 啟動完成后瀏覽器打開網站: http://IP:5000 相關賬戶…

Git 常用命令備忘

1、刪除 (1)、git push origin --delete dev 刪除遠程分支 (2)、git branch -d dev 刪除本地分支 git branch -D dev 強制刪除本地分支 2、創建分支 (1)、git checkout -b dev 創建本地分支 (2)、git push origin dev 創建遠程分支,此時本地分支與遠程…

02-android studio實現下拉列表+單選框+年月日功能

一、下拉列表功能 1.效果圖 2.實現過程 1&#xff09;添加組件 <LinearLayoutandroid:layout_width"match_parent"android:layout_height"wrap_content"android:layout_marginLeft"20dp"android:layout_marginRight"20dp"android…

表單驗證的藝術:WebKit 支持 HTML 表單的全面解析

表單驗證的藝術&#xff1a;WebKit 支持 HTML 表單的全面解析 在 Web 開發的多彩世界中&#xff0c;表單是用戶與網頁交互的重要橋梁。WebKit 作為眾多現代瀏覽器的渲染引擎&#xff0c;提供了強大的 HTML 表單支持和驗證功能。本文將深入探討 WebKit 如何支持 HTML 表單和進行…

力扣225題解析:使用隊列實現棧的三種解法(Java實現)

引言 在算法和數據結構中&#xff0c;如何用隊列實現棧是一個常見的面試題和實際應用問題。本文將探討力扣上的第225題&#xff0c;通過不同的方法來實現這一功能&#xff0c;并分析各種方法的優劣和適用場景。 問題介紹 力扣225題目要求我們使用隊列實現棧的下列操作&#…

【CMake】基本概念和快速入門

#1. install 是什么 在CMake或項目構建中&#xff0c;install步驟通常指的是將生成的可執行文件、庫文件、頭文件和其他資源復制到指定的安裝目錄&#xff0c;以便進行發布、部署或在其他項目中使用。這個過程通常包括以下內容&#xff1a; 1. 安裝目標 安裝目標是指需要安裝…

運維系列.Nginx中使用HTTP壓縮功能

運維專題 Nginx中使用HTTP壓縮功能 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/qq_28550…

【刷題匯總--字符串中找出連續最長的數字串、島嶼數量、拼三角】

C日常刷題積累 今日刷題匯總 - day0071、字符串中找出連續最長的數字串1.1、題目1.2、思路1.3、程序實現 -- 比較1.4、程序實現 -- 雙指針 2、島嶼數量2.1、題目2.2、思路2.3、程序實現 - dfs 3、拼三角3.1、題目3.2、思路3.3、程序實現 -- 蠻力法3.4、程序實現 -- 巧解(單調性…

pwm 呼吸燈(如果燈一直亮或者一直滅)

&#xff08;這個文章收藏在我的csdn keil文件夾下面&#xff09; 如果這樣設置預分頻和計數周期&#xff0c;那么算出來的pwm頻率如下 人眼看起來就只能是一直亮或者滅&#xff0c;因為pwm的頻率太高了&#xff0c;但是必須是頻率夠高&#xff0c;才能實現呼吸燈的緩慢亮緩慢…

SPL-404:如何徹底改變Solana上的NFT與DeFi

在不斷發展的數字資產領域中&#xff0c;非同質化Token&#xff08;NFT&#xff09;已成為一股革命性力量&#xff0c;徹底改變了我們對數字所有權的看法和互動方式。從藝術和收藏品到游戲和虛擬房地產&#xff0c;NFT吸引了創作者、投資者和愛好者的想象力。 本指南將帶您進入…

MySQL數據庫文件在Linux下存放位置

數據庫文件默認在&#xff1a;cd /usr/share/mysql 配置文件默認在&#xff1a;/etc/my.cnf 數據庫目錄&#xff1a;/var/lib/mysql/ 配置文件&#xff1a;/usr/share/mysql(mysql.server命令及配置文件) 相關命令&#xff1a;/usr/bin(mysqladmin、mysqldump等命令)(*mysql的一…

MyBatisPlus-分頁插件的基本使用

目錄 配置插件 使用分頁API 配置插件 首先&#xff0c;要在配置類中注冊MyBatisPlus的核心插件&#xff0c;同時添加分頁插件。&#xff08;可以放到config軟件包下&#xff09; 可以看到&#xff0c;我們定義了一個配置類&#xff0c;在配置類里聲明了一個Bean,這個Bean的名…

排序 -- 計數排序以及對排序的總結

到了這篇文章就說明常見的排序我們就快要講完了&#xff0c;那這篇文章我們就講一下非比較排序--計數排序。 一、非比較排序 1.基本思想 計數排序又稱為鴿巢原理&#xff0c;是對哈希直接定址法的變形應用。 操作步驟&#xff1a; 統計相同元素出現次數 根據統計的結果將序列…