十四、OPTIM

一、torch.optim

torch.optim.Optimizer(params, defaults)優化器官網說明
在這里插入圖片描述
由官網給的使用說明打開看出來優化器實驗步驟:

①構造選擇優化器

例如采用隨機梯度下降優化器SGD
torch.optim.SGD(beyond.parameters(),lr=0.01),放入beyond模型的參數parameters;學習率learning rate;
每個優化器都有其特定獨有的參數

②把網絡中所有的可用梯度全部設置為0

optim.zero_grad()
梯度為tensor中的一個屬性,這就是為啥神經網絡傳入的數據必須是tensor數據類型的原因,grad這個屬性其實就是求導,常用在反向傳播中,也就是通過先通過正向傳播依次求出結果,再通過反向傳播求導來依次倒退,其目的主要是對參數進行調整優化,詳細的學習了解可自行百度。

③通過反向傳播獲取損失函數的梯度

result_loss.backward()
這里使用的損失函數為loss,其對象為result_loss,當然也可以使用其他的損失函數
從而得到每個可以調節參數的梯度

④調用step方法,對每個梯度參數進行調優更新

optim.step()
使用優化器的step方法,會利用之前得到的梯度grad,來對模型中的參數進行更新

二、優化器的使用

使用CIFAR-10數據集的測試集,使用之前實現的網絡模型,二、復現網絡模型訓練CIFAR-10數據集

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset_testset = torchvision.datasets.CIFAR10("CIFAR_10",train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(dataset_testset,batch_size=2)class Beyond(nn.Module):def __init__(self):super(Beyond,self).__init__()self.model = torch.nn.Sequential(torch.nn.Conv2d(3,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,64,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Flatten(),torch.nn.Linear(1024,64),torch.nn.Linear(64,10))def forward(self,x):x = self.model(x)return x
loss = nn.CrossEntropyLoss()#構建選擇損失函數為交叉熵
beyond = Beyond()
#print(beyond)
optim = torch.optim.SGD(beyond.parameters(),lr=0.01)for epoch in range(30):#進行30輪訓練sum_loss = 0.0for data in dataloader:imgs, targets = dataoutput = beyond(imgs)# print(output)# print(targets)result_loss = loss(output, targets)# print(result_loss)optim.zero_grad()#把網絡模型中所有的梯度都設置為0result_loss.backward()#反向傳播獲得每個參數的梯度從而可以通過優化器進行調優optim.step()#print(result_loss)sum_loss = sum_loss + result_lossprint(sum_loss)"""
tensor(9431.9678, grad_fn=<AddBackward0>)
tensor(7715.2842, grad_fn=<AddBackward0>)
tensor(6860.3115, grad_fn=<AddBackward0>)
......"""

在optim.zero_grad()及其下面三行處,左擊打個斷點,進入Debug模式(Shift+F9)下,
網絡模型名稱---Protected Attributes---__modules---0-8隨便選一個,例如'0'---weight---grad就是參數的梯度

在這里插入圖片描述

三、自動調整學習速率設置

torch.optim.lr_scheduler.ExponentialLR(optimizer=optim,gamma=0.1)
optimizer為優化器的名稱,gamma表示每次都會將原來的lr乘以gamma
使用optim優化器,每次就會在原來的學習速率的基礎上乘以0.1

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset_testset = torchvision.datasets.CIFAR10("CIFAR_10",train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(dataset_testset,batch_size=2)class Beyond(nn.Module):def __init__(self):super(Beyond,self).__init__()self.model = torch.nn.Sequential(torch.nn.Conv2d(3,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,64,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Flatten(),torch.nn.Linear(1024,64),torch.nn.Linear(64,10))def forward(self,x):x = self.model(x)return x
loss = nn.CrossEntropyLoss()#構建選擇損失函數為交叉熵
beyond = Beyond()
#print(beyond)
optim = torch.optim.SGD(beyond.parameters(),lr=0.01)
scheduler = ExponentialLR(optimizer=optim,gamma=0.1)#在原來的lr上乘以gammafor epoch in range(30):#進行30輪訓練sum_loss = 0.0for data in dataloader:imgs, targets = dataoutput = beyond(imgs)# print(output)# print(targets)result_loss = loss(output, targets)# print(result_loss)optim.zero_grad()#把網絡模型中所有的梯度都設置為0result_loss.backward()#反向傳播獲得每個參數的梯度從而可以通過優化器進行調優optim.step()#print(result_loss)sum_loss = sum_loss + result_lossscheduler.step()#這里就需要不能用優化器,而是使用自動學習速率的優化器print(sum_loss)"""
tensor(9469.4385, grad_fn=<AddBackward0>)
tensor(7144.1514, grad_fn=<AddBackward0>)
tensor(6734.8311, grad_fn=<AddBackward0>)
......"""

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/377545.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/377545.shtml
英文地址,請注明出處:http://en.pswp.cn/news/377545.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Windows下運行jekyll,編碼已不再是問題

很久沒更新jekyll了&#xff0c;所以好奇著去官網看了下更新記錄&#xff0c;發現如下更新條目&#xff08;版本1.3.0/2013-11-04發布&#xff09;&#xff1a; Add encoding configuration option (#1449)之前在windows下安裝jekyll運行編寫的代碼時&#xff0c;如果有中文&am…

leetcode 滑動窗口小結 (二)

目錄424. 替換后的最長重復字符思考分析1優化1004. 最大連續1的個數 III友情提醒方法1&#xff0c;基于當前最大頻數方法2&#xff0c;基于歷史最大頻數424. 替換后的最長重復字符 https://leetcode-cn.com/problems/longest-repeating-character-replacement/ 給你一個僅由大…

軟件故障_一些主要的軟件故障

軟件故障The need for software engineering was realized by the software industry after some of its major failures. Some of these failures are listed below, 在經歷了一些重大失敗之后&#xff0c;軟件行業意識到了對軟件工程的需求 。 下面列出了其中一些故障&#x…

十五、修改VGG16網絡來適應自己的需求

一、VGG-16 VGG-16神經網絡是所訓練的數據集為ImageNet ImageNet數據集中驗證集和測試集一萬五千張&#xff0c;有一千個類別 二、加載VGG-16神經網絡模型 VGG16模型使用說明 torchvision.models.vgg16(pretrainedFalse) 其中參數pretrained表示是否下載已經通過ImageNet數…

leetcode 滑動窗口小結 (三)

目錄978. 最長湍流子數組題目思路分析以及代碼1052. 愛生氣的書店老板題目思考分析與初步代碼優化思路以及優化代碼1208. 盡可能使字符串相等題目思考分析以及代碼978. 最長湍流子數組 https://leetcode-cn.com/problems/longest-turbulent-subarray/ 題目 當 A 的子數組 A[…

JAVA多線程學習3--線程一些方法

一、通過sleep方法睡眠 在指定的毫秒數內讓當前正在執行的線程休眠&#xff08;暫停執行&#xff09;。該線程不丟失任何監視器的所屬權。 二、線程優先級 線程具有優先級&#xff0c;范圍為1-10。 MAX_PRIORITY線程可以具有的最高優先級。int類型&#xff0c;值為10. MIN_PRIO…

mcq 隊列_MCQ | 量子密碼學

mcq 隊列1) Which possible Attacks in Quantum Cryptography can take place? 1)量子密碼術中可能發生哪些攻擊&#xff1f; Possible Attacks in Quantum Cryptography and Birthday Attack 量子密碼術和生日攻擊的可能攻擊 Birthday attack and Boomerang attack 生日襲擊…

《inside the c++ object model》讀書筆記 之一:對象

關于對象 ...引子:在C語言中,"數據"和"處理數據的操作(函數)"是分開來聲明的,語言本身并沒有支持"數據和函數"之間關聯性,這種程序成為"程序性的",由一組"分布在各個一功能為向導的函數中"的算法驅動,他們處理的是共同的外部…

十六、保存和加載自己所搭建的網絡模型

一、保存自己搭建的模型方法一 例如&#xff1a;基于VGG16網絡模型架構的基礎上加上了一層線性層&#xff0c;最后的輸出為10類 torch.save(objmodule,f"path")&#xff0c;傳入需要保存的模型名稱以及要保存的路徑位置 保存模型結構和模型的參數&#xff0c;保存文…

uC/OS-II OS_TASK.C中有關任務管理的函數

函數大致用途 OS_TASK.C是uC/OS-II有關任務管理的文件&#xff0c;它定義了一些函數&#xff1a;建立任務、刪除任務、改變任務的優先級、掛起和恢復任務&#xff0c;以及獲取有關任務的信息。 函數用途OSTaskCreate()建立任務OSTaskCreateExt()擴展建立任務OSTaskStkChk()堆…

windows下寫的腳本,在linux下執行失敗

Windows中的換行符為CRLF, 即正則表達式的rn(ASCII碼為13和10), 而Unix(或Linux)換行符為LF, 即正則表達式的n. 在Windows和Linux下協同工作的時候, 往往這個細小的差別就導致問題, 如 1)Windows下寫的Shell腳本, 在Linux下運行時往往出現rn是無效參數, 不能執行; 2)vi 等編器下…

Scala中的do ... while循環

做...在Scala循環 (do...while loop in Scala) do...while loop in Scala is used to run a block of code multiple numbers of time. The number of executions is defined by an exit condition. If this condition is TRUE the code will run otherwise it runs the first …

十七、完整神經網絡模型訓練步驟

以CIFAR-10數據集為例&#xff0c;訓練自己搭建的神經網絡模型架構 一、準備CIFAR-10數據集 CIFAR10官網使用文檔 torchvision.datasets.CIFAR10(root"./CIFAR_10",trainTrue,downloadTrue) 參數描述root字符串&#xff0c;指明要下載到的位置&#xff0c;或已有數…

μC/OS-Ⅱ 操作系統內核知識

目錄μC/OS-Ⅱ任務調度1.任務控制塊2.任務管理3.任務狀態μC/OS-Ⅱ時間管理μC/OS-Ⅱ內存管理內存控制塊MCBμC/OS-Ⅱ任務通信1.事件2.事件控制塊ECB3.信號量4.郵箱5.消息隊列操作系統內核&#xff1a;在多任務系統中&#xff0c;提供任務調度與切換、中斷服務 操作系統內核為每…

第二版tapout

先說說上次流回來的芯片的測試情況。 4月23日&#xff0c; 芯片采用裸片直接切片&#xff0c; bond在板子上&#xff0c;外面加了一個小塑料殼來保護&#xff0c;我們就直接拿回來測試了。 測試的主要分為模擬和數字兩部分&#xff0c; 數字部分的模塊基本都工作正常&#xff0…

cd-rom門鎖定什么意思_CD-ROM的完整形式是什么?

cd-rom門鎖定什么意思CD-ROM&#xff1a;光盤只讀存儲器 (CD-ROM: Compact Disc Read-Only Memory) CD-ROM is an abbreviation of "Compact Disc Read-Only Memory". It is a data storage memory in the form of an optical compact disc, which is read by a syst…

遠程工作時的協作工具

遠程工作時的協作工具 Google Hangout 用于日常會議和面對面交談,在國內其實可以用qq來帶起。Campfire 用于一天來的持續對話。Screenhero 用于分享屏幕&#xff0c;一起寫代碼,這個比較有用,可以一起寫代碼。Balsamiq 用于計劃要制作的 UI。Asana 用于管理任務Google Docs 用于…

十八、完整神經網絡模型驗證步驟

網絡訓練好了&#xff0c;需要提供輸入進行驗證網絡模型訓練的效果 一、加載測試數據 創建python測試文件&#xff0c;beyond_test.py 保存在dataset文件夾下a文件夾里的1.jpg小狗圖片 二、讀取測試圖片&#xff0c;重新設置模型所規定的大小(32,32)&#xff0c;并轉為tens…

二分法變種小結(leetcode 34、leetcode33、leetcode 81、leetcode 153、leetcode 74)

目錄二分法細節1、leetcode 34 在排序數組中查找元素的第一個和最后一個位置2、不完全有序下的二分查找(leetcode33. 搜索旋轉排序數組)3、含重復元素的不完全有序下的二分查找(81. 搜索旋轉排序數組 II)3、不完全有序下的找最小元素(153. 尋找旋轉排序數組中的最小值)4、二維矩…

ID3D11DeviceContext::Dispatch與numthread筆記

假定——[numthreads(TX, TY, TZ)] // 線程組尺寸。既線程組內有多少個線程。Dispatch(GX, GY, GZ); // 線程組的數量。既有多少個線程組。 那么——SV_GroupThreadID{iTX, iTY, iTZ} // 【線程組內的】線程3D編號SV_GroupID{iGX, iGY, iGZ} // 線程組的3D編號SV_DispatchT…