【MindSpore學習打卡】應用實踐-計算機視覺-ShuffleNet圖像分類:從理論到實踐

在當今的深度學習領域,卷積神經網絡(CNN)已經成為圖像分類任務的主流方法。然而,隨著網絡深度和復雜度的增加,計算資源的消耗也顯著增加,特別是在移動設備和嵌入式系統中,這種資源限制尤為突出。ShuffleNet作為一種高效的卷積神經網絡,通過引入Pointwise Group Convolution和Channel Shuffle兩種操作,大大降低了計算量,同時保持了較高的分類精度。在本篇博客中,我們將詳細探討ShuffleNet的設計原理,并通過MindSpore框架實現ShuffleNet在CIFAR-10數據集上的訓練與評估,幫助讀者更好地理解和應用這一高效的網絡結構。

ShuffleNet網絡介紹

ShuffleNetV1是曠視科技提出的一種計算高效的CNN模型,主要應用在移動端。其設計核心在于引入了兩種操作:Pointwise Group Convolution和Channel Shuffle。這些操作在保持精度的同時,大大降低了模型的計算量。

Pointwise Group Convolution

Pointwise Group Convolution:我們在代碼中定義了一個GroupConv類,用于實現逐點分組卷積。這種卷積操作通過將輸入特征圖分成多個組,每組單獨進行卷積操作,從而顯著減少了參數量和計算量。具體來說,逐點分組卷積的卷積核大小為 1 × 1 1 \times 1 1×1,這使得每個卷積核只作用于一個通道,進一步降低了計算復雜度。

Group Convolution

class GroupConv(nn.Cell):def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False):super(GroupConv, self).__init__()self.groups = groupsself.convs = nn.CellList()for _ in range(groups):self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,kernel_size=kernel_size, stride=stride, has_bias=has_bias,padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))def construct(self, x):features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)outputs = ()for i in range(self.groups):outputs = outputs + (self.convs[i](features[i].astype("float32")),)out = ops.cat(outputs, axis=1)return out

Channel Shuffle

Channel Shuffle:為了克服分組卷積帶來的不同組別通道無法進行信息交流的問題,ShuffleNet引入了Channel Shuffle機制。我們在代碼中實現了一個channel_shuffle方法,通過對通道進行重排,使得不同組別的通道能夠進行信息交互。這一步驟在保持網絡高效性的同時,增強了特征的表達能力。
Channel Shuffle

def channel_shuffle(self, x):batchsize, num_channels, height, width = ops.shape(x)group_channels = num_channels // self.groupx = ops.reshape(x, (batchsize, group_channels, self.group, height, width))x = ops.transpose(x, (0, 2, 1, 3, 4))x = ops.reshape(x, (batchsize, num_channels, height, width))return x

ShuffleNet模塊

ShuffleNet模塊:在ShuffleNet模塊中,我們結合了Pointwise Group Convolution和Channel Shuffle,并在降采樣模塊中引入了步長為2的Depth Wise Convolution。這種設計不僅提高了網絡的計算效率,還保證了特征提取的有效性。在代碼實現中,我們通過ShuffleV1Block類定義了ShuffleNet的基本模塊,并在其中實現了上述操作。

ShuffleNet模塊

class ShuffleV1Block(nn.Cell):def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):super(ShuffleV1Block, self).__init__()self.stride = stridepad = ksize // 2self.group = groupif stride == 2:outputs = oup - inpelse:outputs = oupself.relu = nn.ReLU()branch_main_1 = [GroupConv(in_channels=inp, out_channels=mid_channels,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=1 if first_group else group),nn.BatchNorm2d(mid_channels),nn.ReLU(),]branch_main_2 = [nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,pad_mode='pad', padding=pad, group=mid_channels,weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(mid_channels),GroupConv(in_channels=mid_channels, out_channels=outputs,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=group),nn.BatchNorm2d(outputs),]self.branch_main_1 = nn.SequentialCell(branch_main_1)self.branch_main_2 = nn.SequentialCell(branch_main_2)if stride == 2:self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')def construct(self, old_x):left = old_xright = old_xout = old_xright = self.branch_main_1(right)if self.group > 1:right = self.channel_shuffle(right)right = self.branch_main_2(right)if self.stride == 1:out = self.relu(left + right)elif self.stride == 2:left = self.branch_proj(left)out = ops.cat((left, right), 1)out = self.relu(out)return out

構建ShuffleNet網絡

ShuffleNet網絡結構如下圖所示。以輸入圖像 224 × 224 224 \times 224 224×224,組數3(g = 3)為例,經過多個ShuffleNet模塊和全局平均池化,最終得到分類結果。

ShuffleNet網絡結構

class ShuffleNetV1(nn.Cell):def __init__(self, n_class=1000, model_size='2.0x', group=3):super(ShuffleNetV1, self).__init__()print('model size is ', model_size)self.stage_repeats = [4, 8, 4]self.model_size = model_sizeif group == 3:if model_size == '0.5x':self.stage_out_channels = [-1, 12, 120, 240, 480]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 240, 480, 960]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 360, 720, 1440]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 480, 960, 1920]else:raise NotImplementedErrorelif group == 8:if model_size == '0.5x':self.stage_out_channels = [-1, 16, 192, 384, 768]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 384, 768, 1536]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 576, 1152, 2304]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 768, 1536, 3072]else:raise NotImplementedErrorinput_channel = self.stage_out_channels[1]self.first_conv = nn.SequentialCell(nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(input_channel),nn.ReLU(),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')features = []for idxstage in range(len(self.stage_repeats)):numrepeat = self.stage_repeats[idxstage]output_channel = self.stage_out_channels[idxstage + 2]for i in range(numrepeat):stride = 2 if i == 0 else 1first_group = idxstage == 0 and i == 0features.append(ShuffleV1Block(input_channel, output_channel,group=group, first_group=first_group,mid_channels=output_channel // 4, ksize=3, stride=stride))input_channel = output_channelself.features = nn.SequentialCell(features)self.globalpool = nn.AvgPool2d(7)self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)def construct(self, x):x = self.first_conv(x)x = self.maxpool(x)x = self.features(x)x = self.globalpool(x)x = ops.reshape(x, (-1, self.stage_out_channels[-1]))x = self.classifier(x)return x

模型訓練和評估

模型訓練和評估:在訓練部分,我們使用了CIFAR-10數據集,并通過數據增強技術(如隨機裁剪和水平翻轉)來提高模型的泛化能力。我們定義了ShuffleNet網絡,并使用交叉熵損失函數和Momentum優化器進行訓練。在評估部分,我們加載訓練好的模型,并在測試集上進行評估,計算模型的Top-1和Top-5準確率,以全面衡量模型的性能。

訓練集準備與加載

首先下載并加載CIFAR-10數據集。CIFAR-10共有60000張32x32的彩色圖像,分為10個類別,其中50000張圖片作為訓練集,10000張圖片作為測試集。

from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"
download(url, "./dataset", kind="tar.gz", replace=True)import mindspore as ms
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transformsdef get_dataset(train_dataset_path, batch_size, usage):image_trans = []if usage == "train":image_trans = [vision.RandomCrop((32, 32), (4, 4, 4, 4)),vision.RandomHorizontalFlip(prob=0.5),vision.Resize((224, 224)),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]elif usage == "test":image_trans = [vision.Resize((224, 224)),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]label_trans = transforms.TypeCast(ms.int32)dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)dataset = dataset.map(image_trans, 'image')dataset = dataset.map(label_trans, 'label')dataset = dataset.batch(batch_size, drop_remainder=True)return datasetdataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "train")
batches_per_epoch = dataset.get_dataset_size()

模型訓練

定義ShuffleNet網絡,并使用交叉熵損失函數和Momentum優化器進行訓練。

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracydef train():mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")net = ShuffleNetV1(model_size="2.0x", n_class=10)loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)min_lr = 0.0005base_lr = 0.05lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,base_lr,batches_per_epoch*250,batches_per_epoch,decay_epoch=250)lr = Tensor(lr_scheduler[-1])optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)callback = [TimeMonitor(), LossMonitor()]save_ckpt_path = "./"config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)callback += [ckpt_callback]print("============== Starting Training ==============")start_time = time.time()model.train(5, dataset, callbacks=callback)use_time = time.time() - start_timehour = str(int(use_time // 60 // 60))minute = str(int(use_time // 60 % 60))second = str(int(use_time % 60))print("total time:" + hour + "h " + minute + "m " + second + "s")print("============== Train Success ==============")if __name__ == '__main__':train()

在這里插入圖片描述

模型評估

在CIFAR-10測試集上對訓練好的模型進行評估。

from mindspore import load_checkpoint, load_param_into_netdef test():mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")net = ShuffleNetV1(model_size="2.0x", n_class=10)param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")load_param_into_net(net, param_dict)net.set_train(False)loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),'Top_5_Acc': Top5CategoricalAccuracy()}model = Model(net, loss_fn=loss, metrics=eval_metrics)start_time = time.time()res = model.eval(dataset, dataset_sink_mode=False)use_time = time.time() - start_timehour = str(int(use_time // 60 // 60))minute = str(int(use_time // 60 % 60))second = str(int(use_time % 60))log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \+ "', time: " + hour + "h " + minute + "m " + second + "s"print(log)filename = './eval_log.txt'with open(filename, 'a') as file_object:file_object.write(log + '\n')if __name__ == '__main__':test()

在這里插入圖片描述

模型預測

在CIFAR-10測試集上對模型進行預測,并將預測結果可視化。

import mindspore
import matplotlib.pyplot as plt
import mindspore.dataset as dsnet = ShuffleNetV1(model_size="2.0x", n_class=10)
show_lst = []
param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
load_param_into_net(net, param_dict)
model = Model(net)
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = dataset_show.batch(16)
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
image_trans = [vision.RandomCrop((32, 32), (4, 4, 4, 4)),vision.RandomHorizontalFlip(prob=0.5),vision.Resize((224, 224)),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)
class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}
# 推理效果展示(上方為預測的結果,下方為推理效果圖片)
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())
output = model.predict(ms.Tensor(predict_data['image']))
pred = np.argmax(output.asnumpy(), axis=1)
index = 0
for image in show_images_lst:plt.subplot(2, 8, index+1)plt.title('{}'.format(class_dict[pred[index]]))index += 1plt.imshow(image)plt.axis("off")
plt.show()

在這里插入圖片描述

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/39698.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/39698.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/39698.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

25計算機考研,這些學校雙非閉眼入,性價比超高!

計算機考研,好的雙非院校也很多! 對于一些二本準備考研的同學來說,沒必要一直盯著985/211這些院校,競爭激烈不說,容易當陪跑,下面這些就是不錯的雙非院校: 燕山大學南京郵電大學南京信息工程大…

WPS-Word文檔表格分頁

一、問題描述 這種情況不好描述 就是像這種表格內容,但是會有離奇的分頁的情況。這種情況以前的錯誤解決辦法就是不斷地調整表格的內容以及間隔顯得很亂,于是今天去查了解決辦法,現在學會了記錄一下避免以后忘記了。 二、解決辦法 首先記…

《昇思25天學習打卡營第5天 | mindspore 網絡構建 Cell 常見用法》

1. 背景: 使用 mindspore 學習神經網絡,打卡第五天; 2. 訓練的內容: 使用 mindspore 的 nn.Cell 構建常見的網絡使用方法; 3. 常見的用法小節: 支持一系列常用的 nn 的操作 3.1 nn.Cell 網絡構建&…

【FFmpeg】關鍵結構體的初始化和釋放(AVFormatContext、AVIOContext等)

目錄 1.AVFormatContext1.1 初始化(avformat_alloc_context)1.2 釋放(avformat_free_context) 2.AVIOContext2.1 初始化(avio_alloc_context)2.2 釋放(avio_context_free) 3. AVStre…

8.SQL注入-基于insert,update利用案例

SQL注入-基于insert/update利用案例 sql語句正常插入表中的數據 insert into member(username,pw,sex,phonenum,address,email) values(xiaoqiang,1111,1,2,3,4); select * from member;例如插入小強數據,如圖所示: 采用or這個運算符,構造…

實測有效:Win11右鍵默認顯示更多

Win11最大的變化之一莫過于右鍵菜單發生了變化,最大的問題是什么,是右鍵菜單很多時候需要點兩次,實在是反人類 第一步 復制以下命令直接運行: reg.exe add "HKCU\Software\Classes\CLSID\{86ca1aa0-34aa-4e8b-a509-50c905ba…

python_zabbix

zabbix官網地址:19. API19. APIhttps://www.zabbix.com/documentation/4.2/zh/manual/api 每個版本可以有些差異,選擇目前的版本在查看對于的api接口#token接口代碼 import requests apiurl "http://zabbix地址/api_jsonrpc.php" data {&quo…

web的學習和開發

這個使同步和異步的區別 今天主要就是學了一些前端,搞了一些前端的頁面,之后準備學一下后端。 我寫的這個項目使百度貼吧,還沒有寫er圖。 先看一下主界面是什么樣子的。 這個是主界面,將來后面的主要功能點基本上全部是放在這個上…

推動能源綠色低碳發展,風機巡檢進入國產超高清+AI時代

全球綠色低碳能源數字轉型發展正在進入一個重要窗口期。風電作為一種清潔能源,在碳中和過程中扮演重要角色,但風電場運維卻是一件十足的“苦差事”。 傳統的風機葉片人工巡檢方式主要依靠巡檢人員利用高倍望遠鏡檢查、高空繞行下降目測檢查(蜘蛛人)、葉…

STM32——Modbus協議

一、Modbus協議簡介: 1.modbus介紹: Modbus是一種串行通信協議,是Modicon公司(現在的施耐德電氣 Schneider Electric)于1979年為使用可編程邏輯控制器(PLC)通信而發表。Modbus已經成為工業領域…

PythonConda系列(親測有效):【解決方案】Collecting package metadata (current_repodata.json): failed

【解決方案】Collecting package metadata (current_repodata.json): failed 問題描述解決方案小結參考文獻 問題描述 在cmd下運行:conda install pylint -y,報錯如下: C:\Users\apr> conda install --name apr pylint -y Co…

PDF壓縮工具選哪個?6款免費PDF壓縮工具分享

PDF文件已經成為一種常見的文檔格式。然而,PDF文件的體積有時可能非常龐大,尤其是在包含大量圖像或復雜格式的情況下。選擇一個高效的PDF壓縮工具就顯得尤為重要。小編今天給大家整理了2024年6款市面上反響不錯的PDF壓縮文件工具。輕松幫助你找到最適合自…

漆包線行業生產管理革新:萬界星空科技MES系統解決方案

一、引言 在科技日新月異的今天,萬界星空科技憑借其在智能制造領域的深厚積累,為漆包線行業量身打造了一套先進的生產管理執行系統(MES)解決方案。隨著市場競爭的加劇,漆包線作為電氣設備的核心材料,其生產…

React+TS前臺項目實戰(二十四)-- 繪制組件Qrcode封裝

文章目錄 前言Qrcode組件1. 功能分析2. 代碼詳細注釋3. 使用方式4. 效果展示(pc端 / 移動端) 總結 前言 今天要封裝的Qrcode 組件,是通過傳入的信息,繪制在二維碼上,可用于很多場景,如區塊鏈項目中的區塊顯示交易地址時就可以用到…

無人值守停車場管理系統具備哪些功能?無人值守收費停車場系統多少錢

隨著城市化進程的加快,停車難已成為制約城市發展的一個突出問題。在傳統停車場管理中,人工收費、車輛登記等環節不僅效率低下,而且容易出錯。無人值守停車系統的出現,無人值守停車場系統以其高效、智能的特點,通過集成…

Meta 3D Gen:文生 3D 模型

是由 Meta 公布的一個利用 Meta AssetGen(模型生成)和 TextureGen(貼圖材質生成)的組合 AI 系統,可以在分分鐘內生成高質量 3D 模型和高分辨率貼圖紋理。 視頻演示的效果非常好,目前只有論文,期…

telegram mini app和game實現登錄功能

接上一篇文章,我們在創建好telegram機器人后,開始開發小游戲或者mini App,那就避免不了登錄功能。 公開鏈接 bot設置教程:https://lengmo714.top/6e79860b.html 參考教程參考教程,telegram已經給我們提供非常多的api,我們在獲取用…

package.json配置詳解

package.json文件 執行 npm init 命令,會在當前目錄生成一個 package.json 文件 這個文檔是你需要知道的關于你的 package.json 文件中需要什么的所有信息。它必須是實際的 JSON,而不僅僅是一個 JavaScript 對象文字。 //package.json {//如果你打算發…

使用vue動態給同一個a標簽添加內容 并給a標簽設置hover,懸浮文字變色,結果鼠標懸浮有的字上面不變色

如果Vue的虛擬DOM更新機制導致樣式更新不及時,你可以嘗試以下幾種方法來解決這個問題: 確保使用響應式數據: 確保你使用的數據是響應式的,并且任何對這些數據的更改都會觸發視圖的更新。在Vue中,你應該使用data對象中的…

多源BFS——AcWing 173. 矩陣距離

多源BFS 定義 多源BFS(多源廣度優先搜索)是一種圖遍歷算法,它是標準BFS(廣度優先搜索)的擴展,主要用于解決具有多個起始節點的最短路徑問題。在多源BFS中,不是從單一源點開始搜索整個圖&#…