MindSpore基礎教程:LeNet-5 神經網絡在MindSpore中的實現與訓練

MindSpore基礎教程:LeNet-5 神經網絡在MindSpore中的實現與訓練

官方文檔教程使用已經棄用的MindVision模塊,本文是對官方文檔的更新
深度學習在圖像識別領域取得了顯著的成功,LeNet-5 作為卷積神經網絡的經典之作,在諸多研究和應用中占有重要地位。本文將詳細介紹如何使用 MindSpore 框架實現并訓練一個 LeNet-5 神經網絡,專注于處理MNIST手寫數字數據集。

前言

MindSpore 是華為推出的一種新型深度學習框架,旨在為用戶提供高效、易用的編程體驗。接下來,我們將通過實例來展示如何在 MindSpore 中構建、訓練和評估一個經典的 LeNet-5 神經網絡。

環境配置

MindSpore官網

LeNet-5 網絡結構簡介

LeNet-5 是一個簡單的卷積神經網絡,包含兩個卷積層和三個全連接層。它經常被用于圖像識別任務,特別是在處理像 MNIST 這樣的手寫數字數據集時表現出色。

數據集準備與預處理

首先,我們需要準備并預處理數據集。在這個例子中,我們將使用 MNIST 數據集。以下函數 create_dataset 負責加載數據集,并進行必要的預處理:

def create_dataset(data_path, batch_size=32, repeat_size=1):"""創建用于訓練的MNIST數據集。此函數負責加載MNIST數據集,對數據進行預處理和轉換,以便它們可以用于訓練神經網絡。數據預處理包括調整圖像大小、重新縮放和類型轉換。參數:data_path (str): MNIST數據集的路徑。這應該是包含MNIST數據文件的目錄路徑。batch_size (int, 可選): 每個數據批次的大小。默認值為32。repeat_size (int, 可選): 數據集重復的次數。這用于增加數據集的大小。默認值為1。步驟:1. 加載MNIST數據集。2. 對圖像執行大小調整操作,將圖像大小統一調整為32x32像素。3. 對圖像進行重新縮放和標準化處理。先將像素值縮放到0-1之間,然后進行標準化。4. 將圖像的格式從高寬通道(HWC)轉換為通道高寬(CHW)。5. 對標簽進行類型轉換,將其轉換為整型(int32)。6. 對數據集進行洗牌、批處理和重復操作,以準備訓練過程。返回:返回一個處理過的MNIST數據集,可以直接用于模型訓練。注意:- 數據集的預處理步驟對于訓練深度學習模型來說是非常重要的,它們會影響訓練的效果和速度。- 調整batch_size和repeat_size可以影響模型訓練時的內存消耗和速度。"""mnist_dataset = ds.MnistDataset(data_path)resize_operation = vision.Resize((32, 32), interpolation=Inter.LINEAR)rescale_normalization_op = vision.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)rescale_op = vision.Rescale(1.0 / 255.0, 0.0)hwc_to_chw_op = vision.HWC2CHW()type_cast_op = transforms.TypeCast(mstype.int32)mnist_dataset = mnist_dataset.map(input_columns="label", operations=type_cast_op)mnist_dataset = mnist_dataset.map(input_columns="image",operations=[resize_operation, rescale_op, rescale_normalization_op,hwc_to_chw_op])mnist_dataset = mnist_dataset.shuffle(buffer_size=10000)mnist_dataset = mnist_dataset.batch(batch_size, drop_remainder=True)mnist_dataset = mnist_dataset.repeat(repeat_size)return mnist_dataset

這個函數將數據集中的圖像調整為統一的大小,并進行重新縮放和標準化。

構建 LeNet-5 模型

LeNet-5 模型的構建在 LeNet5 類中實現。此類定義了網絡的各層及其排列:

class LeNet5(nn.Cell):"""LeNet-5 神經網絡結構。這是一個經典的卷積神經網絡,通常用于圖像識別任務。它包含了兩個卷積層和三個全連接層。參數:num_class (int): 輸出層的類別數量。默認為10,適用于MNIST數據集。num_channel (int): 輸入圖像的通道數。對于灰度圖像,此值為1。組件:- conv1: 第一個卷積層,使用有效填充。- conv2: 第二個卷積層,同樣使用有效填充。- fc1: 第一個全連接層。- fc2: 第二個全連接層。- fc3: 第三個全連接層,輸出層。- relu: 激活函數,使用ReLU。- max_pool2d: 最大池化層。- flatten: 扁平化層,用于全連接層之前的數據轉換。方法:- construct(x): 定義了前向傳播的過程。"""def __init__(self, num_class=10, num_channel=1):super(LeNet5, self).__init__()self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))self.relu = nn.ReLU()self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)self.flatten = nn.Flatten()def construct(self, x):x = self.conv1(x)x = self.relu(x)x = self.max_pool2d(x)x = self.conv2(x)x = self.relu(x)x = self.max_pool2d(x)x = self.flatten(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)return x

訓練模型

接下來,我們定義 train_network 函數來訓練模型。此函數接受模型實例、數據集路徑和其他訓練參數:

def train_network(model, epoch_size, data_path, repeat_size, checkpoint_callback):"""訓練神經網絡模型。此函數負責初始化數據集,然后使用指定的模型進行訓練。在訓練過程中,它將記錄損失并保存模型的檢查點。參數:model (Model): 要訓練的神經網絡模型。epoch_size (int): 訓練過程中遍歷數據集的次數。data_path (str): 訓練數據集的路徑。repeat_size (int): 數據集的重復次數,用于擴充數據集。checkpoint_callback (Callback): 用于保存模型檢查點的回調函數。過程:- 使用 `create_dataset` 函數創建訓練數據集。- 調用模型的 `train` 方法進行訓練。- 在訓練過程中,會通過回調函數記錄損失和保存檢查點。注意:- 確保提供的 `data_path` 包含適當格式的數據。"""print("============== 開始訓練 ==============")ds_train = create_dataset(data_path, 32, repeat_size)model.train(epoch_size, ds_train, callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor()],dataset_sink_mode=False)print("============== 訓練結束 ==============")

主函數

最后,我們通過 train 函數和 parse_arguments 函數將所有步驟串聯起來。train 函數負責初始化模型、損失函數、優化器和檢查點回調,然后調用 train_network 進行訓練:

def train(args):"""初始化并訓練LeNet-5神經網絡模型。此函數設置了網絡模型、損失函數、優化器,并定義了模型檢查點。然后,使用指定的參數調用 `train_network` 函數來進行模型的訓練。參數:args (Namespace): 一個包含訓練參數的命名空間對象。此對象應該包含以下屬性:- epochs (int): 模型訓練的迭代次數。- data_url (str): 訓練數據集的路徑。- output_path (str): 保存模型檢查點的路徑。過程:1. 創建 LeNet-5 網絡實例。2. 定義損失函數為 Softmax Cross-Entropy。3. 定義優化器為 Momentum 優化器。4. 創建模型實例,并指定網絡、損失函數、優化器和評估指標。5. 設置模型檢查點配置。6. 初始化模型檢查點回調函數。7. 調用 `train_network` 函數進行訓練。注意:- 確保 `args` 對象包含正確和完整的訓練參數。- 調整優化器和損失函數的參數可以對訓練結果產生影響。- 模型檢查點將保存在 `args.output_path` 指定的路徑中。"""net = LeNet5()net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)model = Model(net, net_loss, net_opt, metrics={"Accuracy": nn.Accuracy()})config_checkpoint = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)checkpoint_callback = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.output_path,config=config_checkpoint)train_network(model, args.epochs, args.data_url, 1, checkpoint_callback)

推理

# 加載網絡
param_dict = load_checkpoint("/root/MyCode/pycharm/lenet5/ckpt/checkpoint_lenet-19_1884.ckpt")
network = LeNet5(num_class=NUM_CLASS, num_channel=1)  # 用您定義的LeNet5類創建模型實例
load_param_into_net(network, param_dict)  # 將參數加載到網絡中
model = Model(network)def predict_digit(img):# 圖像預處理img = cv2.resize(img, (32, 32))  # 調整圖像大小為32x32img = np.array(img, dtype=np.float32)  # 轉換圖像數據類型img = (img - 0.1307) / 0.3081  # 對圖像進行標準化處理img = img[np.newaxis, np.newaxis, :, :]  # 改變圖像形狀以符合網絡輸入要求(1, 1, 32, 32)# 將圖像數據轉換為MindSpore張量img_tensor = Tensor(img)# 使用模型進行預測output = model.predict(img_tensor)# 將輸出轉換為概率分布probabilities = Softmax()(output)# 獲取每個類別的概率probabilities_np = probabilities.asnumpy()[0]# 將概率轉換為字典格式labels = [str(i) for i in range(10)]  # 類別標簽,例如"0", "1", "2", ..., "9"probabilities_dict = {label: prob for label, prob in zip(labels, probabilities_np)}return probabilities_dictgr.Interface(fn=predict_digit,inputs=gr.Image(image_mode='L'),outputs=gr.Label(num_top_classes=NUM_CLASS),live=False,css=".footer {display:none !important}",title="0-9數字畫板",description="畫0-9數字",thumbnail="https://raw.githubusercontent.com/gradio-app/real-time-mnist/master/thumbnail2.png"
).launch()

結論

通過本文的指南,您可以在 MindSpore 框架中實現并訓練一個經典的 LeNet-5 神經網絡。LeNet-5 在圖像識別任務中展現了卓越的性能,而 MindSpore 的高效和易用性使得深度學習研究和開發更加便捷。您可以根據本文的指導進行實驗,并根據需要調整網絡結構和訓練參數。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/165273.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/165273.shtml
英文地址,請注明出處:http://en.pswp.cn/news/165273.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Linux | 從虛擬地址到物理地址

前言 本章主要講解虛擬地址是怎么轉化成物理地址的,以及頁表相關知識;本文環境默認為32位機器下;如果你連什么是虛擬地址都不知道可以先看看下面這篇文章; Linux | 進程地址空間-CSDN博客 一、概念補充 頁表:是一種數據…

【性能優化】CPU利用率飆高與內存飆高問題

📫作者簡介:小明java問道之路,2022年度博客之星全國TOP3,專注于后端、中間件、計算機底層、架構設計演進與穩定性建設優化,文章內容兼具廣度、深度、大廠技術方案,對待技術喜歡推理加驗證,就職于…

2023APMCM亞太杯數學建模選題建議及初步思路

大家好呀,亞太杯數學建模開始了,來說一下初步的選題建議吧: 首先定下主基調,本次亞太杯推薦選擇B題。 C題如果想做好,搜集數據難度并不低,并且模型比較簡單,此外目前選擇的人數過多&#xff0c…

java項目之消防物資存儲系統(ssm+vue)

項目簡介 消防物資存儲系統實現了以下功能: 管理員功能: 管理員登陸后,主要模塊包括首頁,個人中心,用戶管理,倉庫管理,物資入庫管理,物資出庫管理,倉庫管理,物資詳情管…

23年下半年軟考成績查詢時間是什么時候?

一、成績查詢時間 2023年下半年軟考成績查詢時間預計2023年12月份公布,成績查詢入口為計算機技術職業資格網(全國統一成績查詢時間,統一查詢入口)。 二、成績查詢方法 登陸中國計算機技術職業資格網,點擊“成績查詢”…

7-9 jmu-python-班級人員信息統計

7-9 jmu-python-班級人員信息統計 分數 15 作者 鄭如濱 單位 集美大學 輸入a,b班的名單,并進行如下統計。 輸入格式: 第1行::a班名單,一串字符串,每個字符代表一個學生,無空格,可能有重復字符。 第2行:&am…

WPF實戰項目十六(客戶端):備忘錄接口

1、新增IMemoService接口&#xff0c;繼承IBaseService接口 public interface IMemoService : IBaseService<MemoDto>{} 2、新增MemoService類&#xff0c;繼承BaseService和IMemoService接口 public class MemoService : BaseService<MemoDto>, IMemoService{pub…

DRF-通用分頁器(PageNumberPagination):ListModelMixin可以使用的通用分頁器

一、ListModelMixin 和GenericAPIView源碼 ListModelMixin 是一個單一功能類&#xff0c;必須配合GenericAPIView&#xff08;或其子類&#xff09;來一起使用&#xff0c;才能完成其視圖的功能 class ListModelMixin:"""List a queryset."""d…

騰訊云點播小程序端上傳 SDK

云點播是專門應對上傳大視頻文件的。 騰訊云點播文檔&#xff1a;https://cloud.tencent.com/document/product/266/18177 這個文檔比較簡單&#xff0c;實在不行&#xff0c;把demo下載下來&#xff0c;一看就明白了&#xff0c;然后再揉一下挪到自己的項目里。完事。 getSign…

芯知識 | 混音播報語音芯片的優勢:革新音頻應用的新力量

隨著科技的進步&#xff0c;語音芯片在各個領域的應用越來越廣泛。而在眾多語音芯片中&#xff0c;混音播報語音芯片以其獨特的優勢&#xff0c;正逐漸成為音頻應用領域的翹楚。本文將重點探討混音播報語音芯片的優勢及其在現代科技應用中的價值。 一、混音播報語音芯片概述 …

element-vue實現網頁鎖屏功能

1.寫一個鎖屏頁面&#xff0c;這里比較簡單&#xff0c;自己定義一下,需要放到底層HTML中哦&#xff0c;比如index.html <div id"appIndex"><el-dialog title"請輸入密碼解鎖屏幕" :visible.sync"lockScreenFlag" :close-on-click-mod…

力扣236. 二叉樹的最近公共祖先(java DFS解法)

Problem: 236. 二叉樹的最近公共祖先 文章目錄 題目描述思路解題方法復雜度Code 題目描述 給定一個二叉樹, 找到該樹中兩個指定節點的最近公共祖先。 百度百科中最近公共祖先的定義為&#xff1a;“對于有根樹 T 的兩個節點 p、q&#xff0c;最近公共祖先表示為一個節點 x&am…

Android逆向一-frida操作

系列文章目錄 第一章 frida操作 文章目錄 系列文章目錄前言一、兩種模式二、frida命令行執行及參數三、frida使用python執行四、動靜態域調用1. 靜態域調用2.動態域調用 五. 遠程rpc調用六. 補充總結 前言 熟悉frida操作&#xff0c;hook手機app的關鍵位置進行逆向操作 一、…

芯知識 | Flash可更換聲音語音芯片—引領音頻IC技術革新的新篇章

隨著科技的飛速發展&#xff0c;人們對于電子產品的音頻性能要求越來越高。在這種背景下&#xff0c;Flash可更換聲音語音芯片應運而生&#xff0c;成為音頻技術領域的一顆璀璨明星。本文將詳細介紹Flash可更換聲音語音芯片的特點、優勢以及應用場景&#xff0c;展望其在未來科…

【Docker】從零開始:10.registry搭建私有倉庫

【Docker】從零開始&#xff1a;10.registry搭建私有倉庫 為什么要使用私有倉庫關于Docker Registry基于容器搭建registry私有倉庫1.下載鏡像2. 啟動鏡像3.修改系統配置文件4.下載ubuntu鏡像&#xff0c;修改名稱3.提交鏡像4.查看鏡像 本地搭建私有倉庫(目前編譯報錯找不到包&a…

【管理運籌學】背誦手冊(五)| 動態規劃

五、動態規劃 基本概念 階段&#xff08;Stage&#xff09;&#xff1a;將所給問題的過程&#xff0c;按時間或空間特征分解成若干相互聯系的階段&#xff0c;以便按次序去求解每階段的解&#xff0c;常用字母 k k k 表示。 狀態&#xff08;State&#xff09;&#xff1a;…

java實現連接linux(上傳文件,執行shell命令等)

1 導入pom <dependency><groupId>com.jcraft</groupId><artifactId>jsch</artifactId><version>0.1.55</version></dependency> 2 編寫配置類 package com.budwk.app.atest;import com.budwk.app.common.config.AppExceptio…

計算機網絡之網絡層

一、概述 主要任務是實現網絡互連&#xff0c;進而實現數據包在各網絡之間的傳輸 1.1網絡引入的目的 從7層結構上看&#xff0c;網絡層下是數據鏈路層 從4層結構上看&#xff0c;網絡層下面是網絡接口層 至少我們看到的網絡層下面是以太網 以太網解決了什么問題&#xff1f; 答…

【Python 千題 —— 基礎篇】刪除列表值

題目描述 題目描述 刪除列表的指定值。有一個列表 [1, 3, 5, 2, 44, 1, 9, 10, 32] &#xff0c;請使用 for 循環刪除該列表中與 [44, 1, 9] 列表相同的值&#xff0c;并輸出該列表。 輸入描述 無輸入。 輸出描述 輸出操作后的列表。 示例 示例 ① 輸出&#xff1a; …

記錄:通過day.js獲取兩個日期相差的時間,并轉化為年月日的格式

day.js這個日期庫真的是很不錯的日期庫&#xff0c;足夠滿足日常的開發需求。 Day.js中文網 (fenxianglu.cn) 需求&#xff1a;獲取兩個日期相差的時間&#xff0c;轉化為年月日的形式&#xff1b;話不多少&#xff0c;直接放代碼 import dayjs from "dayjs"; imp…