昇思25天學習打卡營第2天|MindSpore快速入門

打卡

目錄

打卡

快速入門案例:minist圖像數據識別任務

案例任務說明

流程

1 加載并處理數據集

2 模型網絡構建與定義

3 模型約束定義

4 模型訓練

5 模型保存

6 模型推理

相關參考文檔入門理解

MindSpore數據處理引擎

模型網絡參數初始化

模型優化器

損失函數

代碼

安裝

從模型訓練到預測推理

self_main_train_and_save.py

self_dataprocess.py

self_network.py

self_modeltrain.py

self_modeltest.py

self_predict.py


快速入門案例:minist圖像數據識別任務

案例任務說明

MINIST數據集是有標簽的圖像數據,圖像數據是0-9的手寫阿拉伯數字。其中,訓練集有6W個,測試集1W個。

目的是訓練一個可以高效識別手寫阿拉伯數字的模型。

流程

1 加載并處理數據集

涉及到的mindspore接口 mindspore.dataset。例如對數據集的map、batch、shuffle等操作,數據列名獲取,對數據集進行迭代訪問、查看數據和標簽的shape和datatype等。

2 模型網絡構建與定義

涉及到 mindspore.nn 類。例如用戶可繼承nn.Cell類來自定義網絡結構,其中的construct類函數包含數據(Tensor)的變換過程。。

3 模型約束定義

包括損失函數、優化器等。如?nn.CrossEntropyLoss() 、nn.SGD(model.trainable_params(), 1e-2)

4 模型訓練

- 定義訓練函數,用set_train設置為訓練模式,執行正向計算、反向傳播和參數優化。

- 定義測試函數,用來評估模型的性能。

5 模型保存

- 兩種保存方式:

1)模型參數保存:mindspore.save_checkpoint(model, "model.ckpt")

2)統一的中間表示(Intermediate Representation,IR)的保存,MindIR同時保存了Checkpoint和模型結構,因此需要定義輸入Tensor來獲取輸入shape。mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

6 模型推理

- 兩種加載方式:

1)模型參數加載:?

> model = network()

> param_dict = mindspore.load_checkpoint("model.ckpt");??

>?param_not_load, _ = mindspore.load_param_into_net(model, param_dict)

2)統一的中間表示(Intermediate Representation,IR)的加載:

> mindspore.set_context(mode=mindspore.GRAPH_MODE)
> graph = mindspore.load("model.mindir")
> model = nn.GraphCell(graph)  ## nn.GraphCell 僅支持圖模式。
> outputs = model(inputs)

保存與加載 — MindSpore master 文檔

相關參考文檔入門理解

MindSpore數據處理引擎

MindSpore 通過對外暴露API層來構建數據圖;內部的Data Processing Pipeline 層用來進行數據加載和預處理多步并行流水線。
高性能數據處理引擎 — MindSpore master 文檔

MindSpore 通過數據集(Dataset)和數據變換(Transforms)實現高效的數據預處理。

數據集 Dataset — MindSpore master 文檔

數據變換 Transforms — MindSpore master 文檔

模型網絡參數初始化

Initializer是MindSpore內置的參數初始化基類,所有內置參數初始化方法均繼承該類。mindspore.nn中提供的神經網絡層封裝均提供weight_initbias_init等入參,可以直接使用實例化的Initializer進行參數初始化。

參數初始化 — MindSpore master 文檔

模型優化器

優化器 — MindSpore master 文檔

損失函數

損失函數 — MindSpore master 文檔

代碼

安裝

pip/conda均可:

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1

從模型訓練到預測推理

訓練:

python self_main_train_and_save.py

推理:

python self_predict.py

self_main_train_and_save.py

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# 用download庫從公開華為云obs桶下載 MINIST 數據集并解壓。因為mindspore.dataset 提供的接口僅支持解壓后的數據文件 
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)    ## 1 加載數據集
train_dataset = MnistDataset('MNIST_Data/train', shuffle=False)
test_dataset = MnistDataset('MNIST_Data/test')
print(train_dataset.get_col_names())   # 打印數據集中包含的數據列名,用于dataset的預處理。輸出['image', 'label']## 2 MindSpore的dataset使用數據處理流水線,這里將處理好的數據集打包為大小為64的batch。
from self_dataprocess import datapipe
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)  
test_dataset = datapipe(test_dataset, 64)  ## 3 數據集加載后,一般以迭代方式獲取數據,然后送入神經網絡中進行訓練。可使用create_tuple_iterator 或create_dict_iterator對數據集進行迭代訪問,查看數據和標簽的shape和datatype。
for image, label in test_dataset.create_tuple_iterator():print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")print(f"Shape of label: {label.shape} {label.dtype}")break“”“Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32Shape of label: (64,) Int32”“”
for data in test_dataset.create_dict_iterator():print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")break## 4 模型訓練
from self_network import Network
from self_modeltrain import train, loss_fn 
from self_modelteset import test
model = Network()
epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")## 5 保存模型
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

self_dataprocess.py

from mindspore.dataset import vision, transforms
def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset

self_network.py

# Define model
from mindspore import nnclass Network(nn.Cell): def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsdef check_network():model = Network()print(model)

self_modeltrain.py

# Instantiate loss function and optimizer
from mindspore import nnloss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 1. Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 3. Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train()     ## 設置當前Cell和所有子Cell的訓練模式。對于訓練和預測具有不同結構的網絡層(如 BatchNorm),將通過這個屬性區分分支。如果設置為True,則執行訓練分支,否則執行另一個分支。默認Truefor batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

self_modeltest.py

from mindspore import nn def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

self_predict.py

## 加載模型
from self_network import Network# Instantiate a random initialized model
model = Network()# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)  
print(param_not_load)   ## param_not_load是未被加載的參數列表,為空時代表所有參數均加載成功。## 加載后的模型可以直接用于預測推理。
model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/40875.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/40875.shtml
英文地址,請注明出處:http://en.pswp.cn/web/40875.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

一個字符串的全部子序列和全排列

在計算機科學中,字符串的子序列和全排列是兩個重要的概念。 1. 子序列 子序列是從一個序列中刪除一些(或不刪除)元素而不改變剩余元素的順序形成的新序列。 例如,字符串 “abc” 的子序列包括: “”(空…

如何選擇TikTok菲律賓直播網絡?

為了滿足用戶對于實時互動的需求,TikTok推出了直播功能,讓用戶能夠與粉絲即時交流。本文將探討如何選擇適合的TikTok菲律賓直播網絡,并分析OgLive是否是值得信賴的選擇。 TikTok菲律賓直播網絡面臨的挑戰 作為全球領先的短視頻平臺&#xff…

Python + OpenCV 開啟圖片、寫入儲存圖片

這篇教學會介紹OpenCV 里imread()、imshow()、waitKey() 方法,透過這些方法,在電腦中使用不同的色彩模式開啟圖片并顯示圖片。 imread() 開啟圖片 使用imread() 方法,可以開啟圖片,imread() 有兩個參數,第一個參數為檔…

Google Play上架:惡意軟件、移動垃圾軟件和行為透明度詳細解析和解決辦法 (一)

近期整理了許多開發者的拒審郵件和內容,也發現了許多問題,今天來說一下關于惡意軟件這類拒審的問題。 目標郵件如下: 首先說一下各位小伙伴留言私信的一個方法,提供你的拒審郵件和時間,盡可能的詳細,這樣會幫助我們的團隊了解你們的問題,去幫助小伙伴么解決問題。由于前…

在 .NET 8 Web API 中實現彈性

在現代 Web 開發中,構建彈性 API 對于確保可靠性和性能至關重要。本文將指導您使用 Microsoft.Extensions.Http.Resilience 庫在 .NET 8 Web API 中實現彈性。我們將介紹如何設置重試策略和超時,以使您的 API 更能抵御瞬時故障。 步驟 1.創建一個新的 .…

集成學習(一)Bagging

前邊學習了:十大集成學習模型(簡單版)-CSDN博客 Bagging又稱為“裝袋法”,它是所有集成學習方法當中最為著名、最為簡單、也最為有效的操作之一。 在Bagging集成當中,我們并行建立多個弱評估器(通常是決策…

排序——數據結構與算法 總結8

目錄 8.1 排序相關概念 8.2 插入排序 8.2.1 直接插入排序: 8.2.2 折半插入排序: 8.2.3 希爾排序: 8.3 交換排序 8.3.1 冒泡排序: 8.3.2 快速排序: 8.4 選擇排序 8.4.1 簡單選擇排序 8.4.2 堆排序 8.5 歸并…

磁盤就是一個超大的Byte數組,操作系統是如何管理的?

磁盤在操作系統的維度看,就是一個“超大的Byte數組”。 那么操作系統是如何對這塊“超大的Byte數組”做管理的呢? 我們知道在邏輯上,上帝說是用“文件”的概念來進行管理的。于是,便有了“文件系統”。那么,文件系統…

當前國內可用的docker加速器搜集 —— 筑夢之路

可用鏡像加速器 以下地址搜集自網絡,僅供參考,請自行驗證。 1、https://docker.m.daocloud.io2、https://dockerpull.com3、https://atomhub.openatom.cn4、https://docker.1panel.live5、https://dockerhub.jobcher.com6、https://hub.rat.dev7、http…

最新版情侶飛行棋dofm,已解鎖高階私密模式,單身狗務必繞道!(附深夜學習資源)

今天阿星要跟大家聊一款讓阿星這個大老爺們兒面紅耳赤的神奇游戲——情侶飛行棋。它的神奇之處就在于專為情侶設計,能讓情侶之間感情迅速升溫,但單身狗們請自覺繞道,不然后果自負哦! 打開游戲,界面清新,操…

HTML5使用<progress>進度條、<meter>刻度條

1、<progress>進度條 定義進度信息使用的是 progress 標簽。它表示一個任務的完成進度&#xff0c;這個進度可以是不確定的&#xff0c;只是表示進度正在進行&#xff0c;但是不清楚還有多少工作量沒有完成&#xff0c;也可以用0到某個最大數字&#xff08;如&#xff1…

vs2022安裝qt vs tool

1 緣由 由于工作的需要&#xff0c;要在vs2022上安裝qt插件進行開發。依次安裝qt&#xff0c;vs2022&#xff0c;在vs2022的擴展管理中安裝qt vs tool。 2 遇到困難 問題來了&#xff0c;在qt vs tool的設置qt version中出現問題&#xff0c;設置msvc_64-bit時出現提示“invali…

西安石油大學 課程習題信息管理系統(數據庫課設)

主要技術棧 Java Mysql SpringBoot Tomcat HTML CSS JavaScript 該課設必備環境配置教程&#xff1a;&#xff08;參考給出的鏈接和給出的關鍵鏈接&#xff09; JAVA課設必備環境配置 教程 JDK Tomcat配置 IDEA開發環境配置 項目部署參考視頻 若依框架 鏈接數據庫格式注…

【中項第三版】系統集成項目管理工程師 | 第 4 章 信息系統架構① | 4.1-4.2

前言 第4章對應的內容選擇題和案例分析都會進行考查&#xff0c;這一章節屬于技術相關的內容&#xff0c;學習要以教材為準。本章分值預計在4-5分。 目錄 4.1 架構基礎 4.1.1 指導思想 4.1.2 設計原則 4.1.3 建設目標 4.1.4 總體框架 4.2 系統架構 4.2.1 架構定義 4.…

Invoice OCR

Invoice OCR 發票識別 其他類型ORC&#xff1a; DIPS_YTPC OCR-CSDN博客

25款404網頁源碼(上)

25款404網頁源碼&#xff08;上&#xff09; 1部分源碼 2部分源碼 3部分源碼 4部分源碼 5部分源碼 6部分源碼 7部分源碼 8部分源碼 9部分源碼 10部分源碼 11部分源碼 12部分源碼 領取完整源碼下期更新 1 部分源碼 <!DOCTYPE html> <html><!-- 優選源碼 gulang.…

數據結構基礎--------【二叉樹基礎】

二叉樹基礎 二叉樹是一種常見的數據結構&#xff0c;由節點組成&#xff0c;每個節點最多有兩個子節點&#xff0c;左子節點和右子節點。二叉樹可以用來表示許多實際問題&#xff0c;如計算機程序中的表達式、組織結構等。以下是一些二叉樹的概念&#xff1a; 二叉樹的深度&a…

Element-UI - el-table中自定義圖片懸浮彈框 - 位置優化

該篇為前一篇“Element-UI - 解決el-table中圖片懸浮被遮擋問題”的優化升級部分&#xff0c;解決當圖片位于頁面底部時&#xff0c;顯示不全問題優化。 Vue.directive鉤子函數已在上一篇中詳細介紹&#xff0c;不清楚的朋友可以翻看上一篇&#xff0c; “Element-UI - 解決el-…

深入刨析Redis存儲技術設計藝術(二)

三、Redis主存儲 3.1、存儲相關結構體 redisServer:服務器 server.h struct redisServer { /* General */ pid_t pid; /* Main process pid. */ pthread_t main_thread_id; /* Main thread id */ char *configfile; /* Absolut…

Interpretability 與 Explainability 機器學習

「AI秘籍」系列課程&#xff1a; 人工智能應用數學基礎人工智能Python基礎人工智能基礎核心知識人工智能BI核心知識人工智能CV核心知識 Interpretability 模型和 Explainability 模型之間的區別以及為什么它可能不那么重要 當你第一次深入可解釋機器學習領域時&#xff0c;你會…