昇思25天學習打卡營第7天 | 模型訓練

內容介紹:

模型訓練一般分為四個步驟:

1. 構建數據集。
2. 定義神經網絡模型。
3. 定義超參、損失函數及優化器。
4. 輸入數據集進行訓練與評估。

具體內容:

1. 導包

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
from download import download

2. 構建數據集

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)

3. 定義神經網絡模型

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()

4.?定義超參、損失函數和優化器

超參(Hyperparameters)是可以調整的參數,可以控制模型訓練優化的過程,不同的超參數值可能會影響模型訓練和收斂速度。目前深度學習模型多采用批量隨機梯度下降算法進行優化。

訓練輪次(epoch):訓練時遍歷數據集的次數。

批次大小(batch size):數據集進行分批讀取訓練,設定每個批次數據的大小。batch size過小,花費時間多,同時梯度震蕩嚴重,不利于收斂;batch size過大,不同batch的梯度方向沒有任何變化,容易陷入局部極小值,因此需要選擇合適的batch size,可以有效提高模型精度、全局收斂。

學習率(learning rate):如果學習率偏小,會導致收斂的速度變慢,如果學習率偏大,則可能會導致訓練不收斂等不可預測的結果。梯度下降法被廣泛應用在最小化模型誤差的參數優化算法上。梯度下降法通過多次迭代,并在每一步中最小化損失函數來預估模型的參數。學習率就是在迭代過程中,會控制模型的學習進度。

epochs = 3
batch_size = 64
learning_rate = 1e-2

損失函數(loss function)用于評估模型的預測值(logits)和目標值(targets)之間的誤差。訓練模型時,隨機初始化的神經網絡模型開始時會預測出錯誤的結果。損失函數會評估預測結果與目標值的相異程度,模型訓練的目標即為降低損失函數求得的誤差。

常見的損失函數包括用于回歸任務的`nn.MSELoss`(均方誤差)和用于分類的`nn.NLLLoss`(負對數似然)等。 `nn.CrossEntropyLoss` 結合了`nn.LogSoftmax`和`nn.NLLLoss`,可以對logits 進行歸一化并計算預測誤差。

loss_fn = nn.CrossEntropyLoss()

模型優化(Optimization)是在每個訓練步驟中調整模型參數以減少模型誤差的過程。MindSpore提供多種優化算法的實現,稱之為優化器(Optimizer)。優化器內部定義了模型的參數優化過程(即梯度如何更新至模型參數),所有優化邏輯都封裝在優化器對象中。在這里,我們使用SGD(Stochastic Gradient Descent)優化器。

我們通過`model.trainable_params()`方法獲得模型的可訓練參數,并傳入學習率超參來初始化優化器。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

在訓練過程中,通過微分函數可計算獲得參數對應的梯度,將其傳入優化器中即可實現參數優化,具體形態如下:

grads = grad_fn(inputs)

optimizer(grads)

5. 訓練與評估

設置了超參、損失函數和優化器后,我們就可以循環輸入數據來訓練模型。一次數據集的完整迭代循環稱為一輪(epoch)。每輪執行訓練時包括兩個步驟:

1. 訓練:迭代訓練數據集,并嘗試收斂到最佳參數。
2. 驗證/測試:迭代測試數據集,以檢查模型性能是否提升。
?

def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)
print("Done!")

MindSpore的易用性也給我帶來了很大的便利。通過簡潔明了的API和豐富的文檔支持,我能夠快速地掌握MindSpore的使用方法,并輕松地構建自己的深度學習模型。同時,MindSpore還提供了豐富的預訓練模型和示例代碼,讓我能夠更快地入門并深入理解深度學習的應用。

在模型訓練的過程中,我深刻體會到了深度學習模型的復雜性和挑戰性。通過不斷地調整網絡結構、優化參數設置以及嘗試不同的訓練策略,我逐漸掌握了如何構建和訓練一個性能優異的深度學習模型。這個過程讓我更加明白了深度學習模型訓練需要耐心、細致和持續的努力。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/35458.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/35458.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/35458.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

手把手教你使用kimi創建流程圖【實踐篇】

學境思源,一鍵生成論文初稿: AcademicIdeas - 學境思源AI論文寫作 引言 在昨日的文章中,我們介紹了如何使用Kimi生成論文中的流程圖。今天,我們將更進一步,通過實踐案例來展示Kimi在生成流程圖方面的應用。這不僅將加…

【大數據技術原理與應用(概念、存儲、處理、分析與應用)】第1章-大數據概述習題與知識點回顧

文章目錄 單選題多選題知識點回顧幾次信息化浪潮主要解決什么問題?信息科技為大數據時代提供哪些技術支撐?數據產生方式有哪些變革?大數據的發展歷程大數據的四個特點(4V)大數據對思維方式的影響大數據有哪些關鍵技術&…

burpsuite 抓https的方法(CA證書操作)

https://cloud.tencent.com/developer/article/1391501

軟考《信息系統運行管理員》-1.2信息系統運維

1.2信息系統運維 傳統運維模式(軟件) 泛化:軟件交付后圍繞其所做的任何工作糾錯:軟件運行中錯誤的發現和改正適應:為適應環境做出的改變用戶支持:為軟件用戶提供的支持 新的不同視角下的運維 “管理”的…

Java 面試指南合集

線程篇 springBoot篇 待更新 黑夜無論怎樣悠長,白晝總會到來。 此文會一直更新哈 如果你希望成功,當以恒心為良友,以經驗為參謀,以當心為兄弟,以希望為哨兵。

拉普拉斯變換與卷積

前面描述 卷積,本文由卷積引入拉普拉斯變換。 拉普拉斯變換就是給傅里葉變換的 iωt 加了個實部,也可以反著理解,原函數乘以 e ? β t e^{-\beta t} e?βt 再做傅里葉變換,本質上都是傅里葉變換的擴展。 加入實部的拉普拉斯變…

【建設方案】智慧園區大數據云平臺建設方案(DOC原件)

大數據云平臺建設技術要點主要包括以下幾個方面: 云計算平臺選擇:選擇安全性高、效率性強、成本可控的云計算平臺,如阿里云、騰訊云等,確保大數據處理的基礎環境穩定可靠。 數據存儲與管理:利用Hadoop、HBase等分布式…

一年Java轉GO|19K|騰訊 CSIG 一二面經

面經哥只做互聯網社招面試經歷分享,關注我,每日推送精選面經,面試前,先找面經哥 背景 學歷:本科工作經驗:一年(不算實習)當前語言:Javabase:武漢部門\崗位:騰訊云? 一…

5000天后的世界:科技引領的未來之路

**你是否想過,5000天后的世界會是什么樣子?** 科技日新月異,改變著我們的生活方式,也引領著人類文明的進程。著名科技思想家凱文凱利在他的著作《5000天后的世界》中,對未來進行了大膽的預測。 **這本書中&#xff0c…

基于微信小程序的在線點餐系統【前后臺+附源碼+LW】

摘 要 隨著社會的發展,社會的各行各業都在利用信息化時代的優勢。計算機的優勢和普及使得各種信息系統的開發成為必需。 點餐小程序,主要的模塊包括實現管理員;管理員用戶,可以對整個系統進行基本的增刪改查,系統的日…

什么是<meta> 標簽

<meta> 標簽是 HTML 文檔頭部 (<head>) 中的一種元數據標簽&#xff0c;用于提供關于 HTML 文檔的信息。雖然它不會直接影響文檔的呈現&#xff0c;但它在搜索引擎優化 (SEO)、瀏覽器行為和文檔元信息方面起著重要作用。以下是一些常見的 <meta> 標簽及其用途…

Opencv+python模板匹配

我們經常玩匹配圖像或者找相似&#xff0c;opencv可以很好實現這個簡單的小功能。 模板是被查找目標的圖像&#xff0c;查找模板在原始圖像中的哪個位置的過程就叫模板匹配。OpenCV提供的matchTemplate()方法就是模板匹配方法&#xff0c;其語法如下&#xff1a; result cv2.…

使用go語言來完成復雜excel表的導出導入

使用go語言來完成復雜excel表的導出導入&#xff08;一&#xff09; 1.復雜表的導入 開發需求是需要在功能頁面上開發一個excel文件的導入導出功能&#xff0c;這里的復雜指定是表內數據夾雜著一對多&#xff0c;多對一的形式&#xff0c;如下圖所示。數據雜亂而且對應不統一。…

中國90米分辨率可蝕性因子K數據

土壤可蝕性因子&#xff08;K&#xff09;數據&#xff0c;基于多種土壤屬性數據計算&#xff0c;所用數據包括土壤黏粒含量&#xff08;%&#xff09;、粉粒含量&#xff08;%&#xff09;、砂粒含量&#xff08;%&#xff09;、土壤有機碳含量&#xff08;g/kg&#xff09;、…

[DALL·E 2] Hierarchical Text-Conditional Image Generation with CLIP Latents

1、目的 CLIP DDPM進行text-to-image生成 2、數據 (x, y)&#xff0c;x為圖像&#xff0c;y為相應的captions&#xff1b;設定和為CLIP的image和text embeddings 3、方法 1&#xff09;CLIP 學習圖像和文本的embedding&#xff1b;在訓練prior和decoder時固定該部分參數 2&a…

開放式耳機什么牌子好一點?親檢的幾款開放式藍牙耳機推薦

不入耳的開放式耳機更好一些&#xff0c;不入耳式耳機佩戴更舒適&#xff0c;適合長時間佩戴&#xff0c;不會引起強烈的壓迫感或耳部不適。不入耳式的設計不需要接觸耳朵&#xff0c;比入耳式耳機更加衛生且不挑耳型&#xff0c;因此備受運動愛好者和音樂愛好者的喜愛。這里給…

MySQL中ALTER LOGFILE GROUP 語句詳解

在 MySQL 的 InnoDB 存儲引擎中&#xff0c;ALTER LOGFILE GROUP 語句用于修改重做日志組&#xff08;redo log group&#xff09;的配置。重做日志是 InnoDB 用來保證事務持久性的一個關鍵組件&#xff0c;它們用于在系統崩潰后恢復數據。 InnoDB 支持多個重做日志組&#xf…

周轉車配料揀貨方案

根據周轉車安裝的電子標簽&#xff0c;被懸掛的掃碼器掃到墨水屏顯示的二維碼&#xff0c;投屏發送配料揀貨的數據。 方便快捷分揀物料

20240625(周二)歐美股市總結:標普納指止步三日連跌,英偉達反彈6.8%,谷歌微軟新高,油價跌1%

美聯儲理事鮑曼鷹派發聲&#xff0c;若通脹沒有持續改善將支持加息&#xff0c;加拿大5月CPI重新加速&#xff0c;對加拿大央行7月降息構成阻礙。美股走勢分化&#xff0c;道指收跌近300點且六日里首跌&#xff0c;英偉達市值重上3.10萬億美元&#xff0c;芯片股指顯著反彈1.8%…

想要用tween實現相機的移動,three.js渲染的canvas畫布上相機位置一點沒動,如何解決??

&#x1f3c6;本文收錄于「Bug調優」專欄&#xff0c;主要記錄項目實戰過程中的Bug之前因后果及提供真實有效的解決方案&#xff0c;希望能夠助你一臂之力&#xff0c;幫你早日登頂實現財富自由&#x1f680;&#xff1b;同時&#xff0c;歡迎大家關注&&收藏&&…