Lnton羚通關于Optimization在【PyTorch】中的基礎知識

OPTIMIZING MODEL PARAMETERS (模型參數優化)
現在我們有了模型和數據,是時候通過優化數據上的參數來訓練了,驗證和測試我們的模型。訓練一個模型是一個迭代的過程,在每次迭代中,模型會對輸出進行猜測,計算猜測數據與真實數據的誤差(損失),收集誤差對其參數的導數(正如前一節我們看到的那樣),并使用梯度下降優化這些參數。

Prerequisite Code ( 先決代碼 )
We load the code from the previous sections on

import torch 
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transformstraining_data = datasets.FashionMNIST(root = "../../data/",train = True,download = True, transform = transforms.ToTensor()
)test_data = datasets.FashionMNIST(root = "../../data/",train = False,download = True, transform = transforms.ToTensor()
)train_dataloader = DataLoader(training_data, batch_size = 32, shuffle = True)
test_dataloader = DataLoader(test_data, batch_size = 32, shuffle = True)class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10)  )def forward(self, x):out = self.flatten(x)out = self.linear_relu_stack(out)return outmodel = NeuralNetwork()

Hyperparameters ( 超參數 )
超參數是可調節的參數,允許控制模型優化過程,不同的超參數會影響模型的訓練和收斂速度。read more

我們定義如下的超參數進行訓練:

Number of Epochs: 遍歷數據集的次數
Batch Size: 每一次使用的數據集大小,即每一次用于訓練的樣本數量
Learning Rate: 每個 batch/epoch 更新模型參數的速度,較小的值會導致較慢的學習速度,而較大的值可能會導致訓練過程中不可預測的行為,例如訓練抖動頻繁,有可能會發散等。

learning_rate = 1e-3
batch_size = 32
epochs = 5

Optimization Loop ( 優化循環 )
我們設置完超參數后,就可以利用優化循環訓練和優化模型;優化循環的每次迭代稱為一個 epoch, 每個 epoch 包含兩個主要部分:

The Train Loop: 遍歷訓練數據集并嘗試收斂到最優參數。
The Validation/Test Loop: 驗證/測試循環—遍歷測試數據集以檢查模型性能是否得到改善。
讓我們簡單地熟悉一下訓練循環中使用的一些概念。跳轉到前面以查看優化循環的完整實現。

Loss Function ( 損失函數 )
當給出一些訓練數據時,我們未經訓練的網絡可能不會給出正確的答案。 Loss function 衡量的是得到的結果與目標值的不相似程度,是我們在訓練過程中想要最小化的 Loss function。為了計算 loss ,我們使用給定數據樣本的輸入進行預測,并將其與真實的數據標簽值進行比較。

常見的損失函數包括nn.MSELoss (均方誤差)用于回歸任務,nn.NLLLoss(負對數似然)用于分類神經網絡。nn.CrossEntropyLoss 結合 nn.LogSoftmax 和 nn.NLLLoss 。

我們將模型的輸出 logits 傳遞給 nn.CrossEntropyLoss ,它將規范化 logits 并計算預測誤差。

# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

Optimizer ( 優化器 )
優化是在每個訓練步驟中調整模型參數以減少模型誤差的過程。優化算法定義了如何執行這個過程(在這個例子中,我們使用隨機梯度下降)。所有優化邏輯都封裝在優化器對象中。這里,我們使用 SGD 優化器; 此外,PyTorch 中還有許多不同的優化器,如 ADAM 和 RMSProp ,它們可以更好地用于不同類型的模型和數據。

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

在訓練的循環中,優化分為3個步驟:

調用 optimizer.zero_grad() 重置模型參數的梯度,默認情況下,梯度是累加的。為了防止重復計算,我們在每次迭代中顯式將他們歸零。
通過調用 loss.backward() 反向傳播預測損失, PyTorch 保存每個參數的損失梯度。
一旦我們有了梯度,我們調用 optimizer.step() 在向后傳遞中收集梯度調整參數。
Full Implementation (完整實現)
我們定義了遍歷優化參數代碼的 train loop, 以及根據測試數據定義了test loop。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms## 數據集
training_data = datasets.FashionMNIST(root="../../data/",train=True,download=True,transform=transforms.ToTensor()
)test_data = datasets.FashionMNIST(root="../../data/",train=False,download=True,transform=transforms.ToTensor()
)## dataloader
train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)## 定義神經網絡
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):out = self.flatten(x)out = self.linear_relu_stack(out)return out## 實例化模型
model = NeuralNetwork()## 損失函數
loss_fn = nn.CrossEntropyLoss()## 優化器
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)## 超參數
learning_rate = 1e-3
batch_size = 32
epochs = 5## 訓練循環
def train_loop(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)for batch, (X, y) in enumerate(dataloader):# 計算預測和損失pred = model(X)loss = loss_fn(pred, y)## 反向傳播optimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")## 測試循環
def test_loop(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")## 訓練網絡
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(train_dataloader, model, loss_fn, optimizer)test_loop(test_dataloader, model, loss_fn)
print("Done!")

Lnton羚通專注于音視頻算法、算力、云平臺的高科技人工智能企業。 公司基于視頻分析技術、視頻智能傳輸技術、遠程監測技術以及智能語音融合技術等, 擁有多款可支持ONVIF、RTSP、GB/T28181等多協議、多路數的音視頻智能分析服務器/云平臺。

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/43027.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/43027.shtml
英文地址,請注明出處:http://en.pswp.cn/news/43027.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

python3 0基礎學習----數據結構(基礎+練習)

python 0基礎學習筆記之數據結構 📚 幾種常見數據結構列表 (List)1. 定義2. 實例:3. 列表中常用方法.append(要添加內容) 向列表末尾添加數據.extend(列表) 將可迭代對象逐個添加到列表中.insert(索引,插入內容) 向指定…

8.17校招 內推 面經

綠泡泡: neituijunsir 交流裙,內推/實習/校招匯總表格 1、校招 | 騰訊2024校園招聘全面啟動(內推) 校招 | 騰訊2024校園招聘全面啟動(內推) 2、校招 | 大華股份2024屆全球校園招聘正式啟動(內推) 校招 | 大華股份2024屆全球校園招聘正式啟動(內推) …

國家一帶一路和萬眾創業創新的方針政策指引下,Live Market探索跨境產業的創新發展

現代社會,全球經濟互聯互通,跨境產業也因此而崛起。為了推動跨境產業的創新發展,中國政府提出了“一帶一路”和“萬眾創業、萬眾創新”的方針政策,旨在促進全球經濟的互聯互通和創新發展。在這個大環境下,Live Market積…

Mariadb高可用MHA

本節主要學習了Mariadb高可用MHA的概述,案例如何構建MHA 提示:以下是本篇文章正文內容,下面案例可供參考 一、概述 1、概念 MHA(MasterHigh Availability)是一套優秀的MySQL高可用環境下故障切換和主從復制的軟件。…

合宙Air724UG LuatOS-Air LVGL API--簡介

為何是 LVGL LVGL 是一個開源的圖形庫,它提供了創建嵌入式 GUI 所需的一切,具有易于使用的圖形元素、漂亮的視覺效果和低內存占用的特點。 LVGL特點: 強大的 控件 :按鈕、圖表、列表、滑動條、圖像等 高級圖形引擎:動…

BIO、NIO和AIO

一.引言 何為IO 涉及計算機核心(CPU和內存)與其他設備間數據遷移的過程,就是I/O。數據輸入到計算機內存的過程即輸入,反之輸出到外部存儲(比如數據庫,文件,遠程主機)的過程即輸出。 I/O 描述了計算機系統…

插入排序優化——超越歸并排序的超級算法

插入排序及優化 插入排序算法算法講解數據模擬代碼 優化思路一、二分查找二、copy函數 優化后代碼算法的用途題目:數星星(POJ2352 star)輸入輸出格式輸入格式:輸出格式 輸入輸出樣例輸入樣例輸出樣例 題目講解步驟如下AC 代碼 插入…

HIVE SQL實現分組字符串拼接concat

在Mysql中可以通過group_concat()函數實現分組字符串拼接,在HIVE SQL中可以使用concat_ws()collect_set()/collect_list()函數實現相同的效果。 實例: abc2014B92015A82014A102015B72014B6 1.concat_wscollect_list 非去重拼接 select a ,concat_ws(-…

Linux系統中基于NGINX的代理緩存配置指南

作為一名專業的爬蟲程序員,你一定知道代理緩存在加速網站響應速度方面的重要性。而使用NGINX作為代理緩存服務器,能夠極大地提高性能和效率。本文將為你分享Linux系統中基于NGINX的代理緩存配置指南,提供實用的解決方案,助你解決在…

C語言刷題訓練DAY.8

1.計算單位階躍函數 解題思路&#xff1a; 這個非常簡單&#xff0c;只需要if else語句即可完成 解題代碼&#xff1a; #include <stdio.h>int main() {int t 0;while(scanf("%d",&t)!EOF){if (t > 0)printf("1\n");else if (t < 0)pr…

大模型基礎02:GPT家族與提示學習

大模型基礎&#xff1a;GPT 家族與提示學習 從 GPT-1 到 GPT-3.5 GPT(Generative Pre-trained Transformer)是 Google 于2018年提出的一種基于 Transformer 的預訓練語言模型。它標志著自然語言處理領域從 RNN 時代進入 Transformer 時代。GPT 的發展歷史和技術特點如下: GP…

【校招VIP】java語言類和對象之map、set集合

考點介紹&#xff1a; map、set集合相關內容是校招面試的高頻考點之一。 map和set是一種專門用來進行搜索的容器或者數據結構&#xff0c;其搜索效率與其具體的實例化子類有關系。 『java語言類和對象之map、set集合』相關題目及解析內容可點擊文章末尾鏈接查看&#xff01; …

深入了解Maven(一)

目錄 一.Maven介紹與功能 二.依賴管理 1.依賴的配置 2.依賴的傳遞性 3.排除依賴 4.依賴的作用范圍 5.依賴的生命周期 一.Maven介紹與功能 maven是一個項目管理和構建工具&#xff0c;是基于對象模型POM實現。 Maven的作用&#xff1a; 便捷的依賴管理&#xff1a;使用…

springboot 使用zookeeper實現分布式隊列

一.添加ZooKeeper依賴&#xff1a;在pom.xml文件中添加ZooKeeper客戶端的依賴項。例如&#xff0c;可以使用Apache Curator作為ZooKeeper客戶端庫&#xff1a; <dependency><groupId>org.apache.curator</groupId><artifactId>curator-framework</…

【java安全】Log4j反序列化漏洞

文章目錄 【java安全】Log4j反序列化漏洞關于Apache Log4j漏洞成因CVE-2017-5645漏洞版本復現環境漏洞復現漏洞分析 CVE-2019-17571漏洞版本漏洞復現漏洞分析 參考 【java安全】Log4j反序列化漏洞 關于Apache Log4j Log4j是Apache的開源項目&#xff0c;可以實現對System.out…

英語——構詞法

按照語言一定的規律創造新詞的方法就叫作構詞法。英語中常見的構詞法包括六種:合成法、派生法、轉化法、混合法、截短法和首尾字母結合法。其中后三種將在第四節“縮寫和簡寫”中進行講解。 第一節 合成法 英語構詞法中把兩個單詞連在一起合成一個新詞,前一個詞修飾或限定后…

前端性能優化——包體積壓縮插件,打包速度提升插件,提升瀏覽器響應的速率模式

前端代碼優化 –其他的優化可以具體在網上搜索 壓縮項目打包后的體積大小、提升打包速度&#xff0c;是前端性能優化中非常重要的環節&#xff0c;結合工作中的實踐總結&#xff0c;梳理出一些 常規且有效 的性能優化建議 ue 項目可以通過添加–report命令&#xff1a; "…

innodb索引與算法

B樹主鍵插入 B樹在innodb的插入有三種模式page_last_insert, page_dirction, page_N_direction 而在bustub里面的B樹就是page_N_direction,如果是自增主鍵的話&#xff0c;就是上面這樣的插入法 FIC優化 (DDL) 選擇性統計 覆蓋索引 MMR ICP優化 自適應hash 全文索引 MySQL…

Rust之編寫自動化測試

1、測試函數的構成&#xff1a; 在最簡單的情形下,Rust中的測試就是一個標注有test屬性的函數。屬性 (attribute)是一種用于修飾Rust代碼的元數據。只需要將#[test]添加到關鍵字fn的上一行便可以將函數轉變為測試函數。當測試編寫完成后,我們可以使用cargo test命令來運行測試…

Flink-----Standalone會話模式作業提交流程

1.Flink的Slot特點: 均分隔離內存,不隔離CPU可以共享:同一個job中,不同算子的子任務才可以共享同一個slot,同時在運行的前提是,屬于同一個slot共享組,默認都是“default”2.Slot的數量 與 并行度 的關系 slot 是一種靜態的概念,表示最大的并發上線并行度是個動態的概念…