昇思25天學習打卡營第7天之二 | 模型保存與加載

1. 保存與加載

在訓練網絡模型的過程中,實際上我們希望保存中間和最后的結果,用于微調(fine-tune)和后續的模型推理與部署,本章節我們將介紹如何保存與加載模型。

1.1 導入依賴

# 導入numpy庫,并將其重命名為np,以便在代碼中引用
import numpy as np# 導入MindSpore庫,這是華為推出的一個開源深度學習框架,用于構建和訓練神經網絡
import mindspore# 從MindSpore庫中導入nn模塊,這個模塊包含了構建神經網絡所需的各種層和函數
from mindspore import nn# 從MindSpore庫中導入Tensor模塊,Tensor是MindSpore中用于表示張量的類
from mindspore import Tensor

1.1定義神經網絡模型

# 定義一個函數,該函數創建一個簡單的全連接神經網絡模型
def network():# 使用nn.SequentialCell創建一個層序列,這是一個容器類,可以包含多個層model = nn.SequentialCell(# 第一個層是一個Flatten層,用于將輸入的二維圖像數據展平為一維向量nn.Flatten(),# 第二個層是一個全連接層,將28x28的輸入節點映射到512個節點nn.Dense(28*28, 512),# 第三個層是一個ReLU激活函數,用于非線性變換nn.ReLU(),# 第四個層是一個全連接層,將512個節點映射到512個節點nn.Dense(512, 512),# 第五個層是一個ReLU激活函數,用于非線性變換nn.ReLU(),# 第六個層是一個全連接層,將512個節點映射到10個節點,對應于10個類別的輸出nn.Dense(512, 10))# 返回創建好的模型return model

1.2 保存和加載模型權重

1.2.1 保存模型

保存模型使用save_checkpoint接口,傳入網絡和指定的保存路徑:

# 創建一個神經網絡模型實例
model = network()# 使用MindSpore的save_checkpoint函數將模型的檢查點保存到文件
# 第一個參數是模型對象
# 第二個參數是文件名,這里保存為"model.ckpt"
mindspore.save_checkpoint(model, "model.ckpt")
# 打印模型結構
print(model)

輸出:

SequentialCell<(0): Flatten<>(1): Dense<input_channels=784, output_channels=512, has_bias=True>(2): ReLU<>(3): Dense<input_channels=512, output_channels=512, has_bias=True>(4): ReLU<>(5): Dense<input_channels=512, output_channels=10, has_bias=True>>

模型大小估算:
model_capacity ≈ 模型參數 * 數據精度(默認是int32類型)大小 = [(784512+512) + (512512+512) + (512*10 +10)] *32bit/8(bit/Byte)= 669704 *4 = 2678824 Byte
可以看到,模型參數量約為67W,占用空間大小應約為2678824字節
實際該模型文件大小為2679017。可以說非常接近了,剩下的字節應該就是文件類型描述符加模型結構描述符之類的內容了。
所以當我們已知一個模型的參數量和參數精度后,實際就可以估算出模型占用的磁盤空間大小了。

1.2.2 加載模型

要加載模型權重,需要先創建相同模型的實例,然后使用load_checkpointload_param_into_net方法加載參數。

# 創建一個神經網絡模型
model = network()# 使用MindSpore的load_checkpoint函數從文件中加載模型的參數和優化器狀態
# 參數是檢查點的文件名,這里加載的文件名為"model.ckpt"
param_dict = mindspore.load_checkpoint("model.ckpt")# 使用MindSpore的load_param_into_net函數將加載的參數字典加載到模型中
# 第一個參數是模型對象
# 第二個參數是參數字典
# 返回值是一個元組,第一個元素是未加載的參數列表,第二個元素是加載的參數列表
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)# 打印未加載的參數列表,如果加載成功,這個列表應該是空的
print(param_not_load)

輸出:

[]

param_not_load是未被加載的參數列表,為空時代表所有參數均加載成功。

1.3 保存和加載MindIR

除Checkpoint外,MindSpore提供了云側(訓練)和端側(推理)統一的中間表示(Intermediate Representation,IR)。可使用export接口直接將模型保存為MindIR。

# 創建網絡模型
model = network()
# 創建一個Tensor對象,它包含一個大小為[1, 1, 28, 28]的矩陣,所有元素都是1,數據類型為float32
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
# 使用mindspore.export函數將模型導出為MINDIR格式
# 第一個參數是模型對象
# 第二個參數是輸入數據,這里使用了一個Tensor對象作為示例
# 第三個參數是文件名,這里導出的文件名為"model"
# 第四個參數是文件格式,這里設置為"MINDIR",表示導出的模型格式
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

MindIR同時保存了Checkpoint和模型結構,因此需要定義輸入Tensor來獲取輸入shape。

已有的MindIR模型可以方便地通過load接口加載,傳入nn.GraphCell即可進行推理。

nn.GraphCell僅支持圖模式。

# 設置MindSpore的執行模式為GRAPH_MODE
mindspore.set_context(mode=mindspore.GRAPH_MODE)
# 加載之前導出的MINDIR模型
graph = mindspore.load("model.mindir")
# 創建一個GraphCell對象,它將graph作為其成員
model = nn.GraphCell(graph)
# 使用模型對輸入數據進行前向計算,得到輸出
outputs = model(inputs)
# 打印輸出的形狀
print(outputs.shape)

輸出:
模型加載f

2. 小結

本文主要介紹了模型的保存和加載,都包括檢查點checkpoint和統一中間表示MindIR(Intermediate Representation)兩種方法,還介紹了模型大小的估算方法。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/37162.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/37162.shtml
英文地址,請注明出處:http://en.pswp.cn/web/37162.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

六月,允許自己做自己,別人做別人

今天結束后&#xff0c;2024 就過去一半了。 年初的規劃完成一半了嗎&#xff1f;如果沒有也沒關系&#xff0c;做你自己繼續前進。 家人來北京旅游&#xff0c;我累趴了 六月初&#xff0c;我搬家了&#xff0c;這次租了一整套房&#xff0c;是一個小倆居、還帶一個小閣樓。…

速盾:視頻cdn和網站cdn的相同點與不同點

CDN&#xff08;Content Delivery Network&#xff09;是一種分布式網絡架構&#xff0c;旨在為用戶提供高效、高質量的內容傳送服務。CDN主要通過將內容分發到全球各地的邊緣節點&#xff0c;并根據用戶的地理位置選擇最近的節點來提供內容&#xff0c;從而加速內容的傳輸并降…

【高考志愿】儀器科學與技術

目錄 一、專業介紹 1.1 專業概述 1.2 專業方向 1.3 主要課程 二、專業技能與素質培養 三、就業前景 四、個人發展規劃建議 五、儀器科學與技術專業排名 六、總結 一、專業介紹 1.1 專業概述 儀器科學與技術專業是一門綜合性極強的學科&#xff0c;它融合了測量、控制…

數學學習與研究雜志社《數學學習與研究》雜志社2024年第6期目錄

課改前沿 基于核心素養的高中數學課堂教學研究——以“直線與圓、圓與圓的位置關系”為例 張亞紅; 2-4 核心素養視角下初中生數學閱讀能力的培養策略探究 賈象虎; 5-7 初中數學大單元教學實踐策略探索 耿忠義; 8-10《數學學習與研究》投稿&#xff1a;cn7kantougao…

使用Python繪制極坐標圖

使用Python繪制極坐標圖 極坐標圖極坐標圖的優點使用場景 效果代碼 極坐標圖 極坐標圖&#xff08;Polar Chart&#xff09;是一種圖表類型&#xff0c;用于顯示在極坐標系中的數據。極坐標圖使用圓形坐標系&#xff0c;角度表示一個變量的值&#xff0c;半徑表示另一個變量的…

線程安全問題(二)——死鎖

死鎖 前言可重入鎖邏輯 兩個線程兩把鎖&#xff08;死鎖&#xff09;死鎖的特點多個線程多把鎖&#xff08;哲學家就餐問題&#xff09;總結 前言 在前面的文章中&#xff0c;介紹了鎖的基本使用方式——鎖 在上一篇文章中&#xff0c;通過synchronized關鍵字進行加鎖操作&am…

XML簡介XML 使用教程XML的基本結構XML的使用場景

學習總結 1、掌握 JAVA入門到進階知識(持續寫作中……&#xff09; 2、學會Oracle數據庫入門到入土用法(創作中……&#xff09; 3、手把手教你開發炫酷的vbs腳本制作(完善中……&#xff09; 4、牛逼哄哄的 IDEA編程利器技巧(編寫中……&#xff09; 5、面經吐血整理的 面試技…

VMware每次打開網絡設置都出現需要運行NetworkManager問題

每次打開都出現這個情況&#xff0c;是因為之前把NetworkManager服務服務關閉&#xff0c;重新輸入命令&#xff1a; sudo systemctl start NetworkManager.service或者 sudo service network-manager restart 即可解決&#xff0c;但是每次開機重啟都要打開就很麻煩&#xf…

【Chapter4】匯編語言及其程序設計,《微機系統》第一版,趙宏偉

一、匯編語言概述 **指令&#xff1a;**指使計算機完成某種操作的命令。 **程序&#xff1a;**完成某種功能的指令序列。 **軟件&#xff1a;**各種程序總稱。 **機器語言&#xff1a;**計算機能直接識別的語言。用機器語言寫出的程序稱為機器代碼。 **匯編語言&#xff1…

Forecasting from LiDAR via Future Object Detection

Forecasting from LiDAR via Future Object Detection 基礎信息 論文&#xff1a;cvpr2022paper https://openaccess.thecvf.com/content/CVPR2022/papers/Peri_Forecasting_From_LiDAR_via_Future_Object_Detection_CVPR_2022_paper.pdfgithub&#xff1a;https://github.co…

SyncUnsafeCell替換Mutex提高性能

1. 背景 在Rust開發過程中&#xff0c;很多情況下需要在不可變的情況下獲取可變性或者在多線程的情況下可以安全的貢獻可變數據。這種情況下我們一般使用**Mutex來實現通過加鎖來實現。現在我們可以通過使用SyncUnsafeCell來替代Mutex**。 2. SyncUnsafeCell SyncUnsafeCell…

【計算機網絡——1.1網絡internet】

網絡 就是用通信線路和通信設備把很多個“主機/端設備“相互聯系。然后按照某種溝通方式&#xff0c;專業術語叫“協議”&#xff0c;共享信息。 **&#xff08; 計算機網絡&#xff1a;節點和邊構成的系統 節點&#xff1a; 主機節點&#xff1a;主機/端設備(手機&#x…

K8S之網絡深度剖析(一)(持續更新ing)

K8S之網絡深度剖析 一 、關于K8S的網絡模型 在K8s的世界上,IP是以Pod為單位進行分配的。一個Pod內部的所有容器共享一個網絡堆棧(相當于一個網絡命名空間,它們的IP地址、網絡設備、配置等都是共享的)。按照這個網絡原則抽象出來的為每個Pod都設置一個IP地址的模型也被稱作為I…

SpringBoot(一)創建一個簡單的SpringBoot工程

Spring框架常用注解簡單介紹 SpringMVC常用注解簡單介紹 SpringBoot&#xff08;一&#xff09;創建一個簡單的SpringBoot工程 SpringBoot&#xff08;二&#xff09;SpringBoot多環境配置 SpringBoot&#xff08;三&#xff09;SpringBoot整合MyBatis SpringBoot&#xff08;四…

3.ROS串口實例

#include <iostream> #include <ros/ros.h> #include <serial/serial.h> #include<geometry_msgs/Twist.h> using namespace std;//運行打開速度控制插件&#xff1a; rosrun rqt_robot_steering rqt_robot_steering //若串口訪問權限不夠&#xff1a…

詳解PEFT庫中LoRA源碼

前言 GitHub項目地址Some-Paper-CN。本項目是譯者在學習長時間序列預測、CV、NLP和機器學習過程中精讀的一些論文&#xff0c;并對其進行了中文翻譯。還有部分最佳示例教程。如果有幫助到大家&#xff0c;請幫忙點亮Star&#xff0c;也是對譯者莫大的鼓勵&#xff0c;謝謝啦~本…

讀書筆記-《Spring技術內幕》(三)MVC與Web環境

前面我們學習了 Spring 最核心的 IoC 與 AOP 模塊&#xff08;讀書筆記-《Spring技術內幕》&#xff08;一&#xff09;IoC容器的實現、讀書筆記-《Spring技術內幕》&#xff08;二&#xff09;AOP的實現&#xff09;&#xff0c;接下來繼續學習 MVC&#xff0c;其同樣也是經典…

Spring底層原理之bean的加載方式八 BeanDefinitionRegistryPostProcessor注解

BeanDefinitionRegistryPostProcessor注解 這種方式和第七種比較像 要實現兩個方法 第一個方法是實現工廠 第二個方法叫后處理bean注冊 package com.bigdata1421.bean;import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.…

解決idea中git無法管理項目中所有需要管理的文件

點擊文件->設置 選擇版本控制—>目錄映射 點擊加號 設置整個項目被Git管理

【python入門】自定義函數

文章目錄 定義自定義函數的基本語法參數類型示例代碼函數作用域匿名函數&#xff08;Lambda&#xff09;閉包裝飾器 Python中的自定義函數允許你編寫一段可重用的代碼塊&#xff0c;這段代碼可以帶參數&#xff08;輸入&#xff09;&#xff0c;并可能返回一個值&#xff08;輸…