嵌入式學習-PyTorch(8)-day24

torch.optim 優化器

torch.optim 是 PyTorch 中用于優化神經網絡參數的模塊,里面實現了一系列常用的優化算法,比如 SGD、Adam、RMSprop 等,主要負責根據梯度更新模型的參數。


🏗? 核心組成

1. 常用優化器

優化器作用典型參數
torch.optim.SGD標準隨機梯度下降,支持 momentumlr, momentum, weight_decay
torch.optim.Adam自適應學習率,效果穩定lr, betas, weight_decay
torch.optim.RMSprop平滑梯度,常用于RNNlr, alpha, momentum
torch.optim.AdamW改進版Adam,解耦正則化lr, weight_decay
torch.optim.Adagrad稀疏特征場景,自動調整每個參數的學習率lr, lr_decay, weight_decay

?演示代碼

import torch
import torch.nn as nn
import torch.optim as optimmodel = nn.Linear(10, 1)  # 一個簡單的線性層
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(100):output = model(torch.randn(4, 10))  # 模擬一個輸入loss = (output - torch.randn(4, 1)).pow(2).mean()  # 假設是 MSE 損失optimizer.zero_grad()  # 梯度清零loss.backward()        # 反向傳播optimizer.step()       # 更新參數

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2ddataset = torchvision.datasets.CIFAR10(root='./data_CIF', train=False, download=True, transform=torchvision.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)class Tudui(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)self.maxpool1 = nn.MaxPool2d(kernel_size=2)self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)self.maxpool2 = nn.MaxPool2d(kernel_size=2)self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)self.maxpool3 = nn.MaxPool2d(kernel_size=2)self.flatten = nn.Flatten()self.linear1 = nn.Linear(in_features=1024, out_features=64)self.linear2 = nn.Linear(in_features=64, out_features=10)self.model1 = nn.Sequential(Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(in_features=1024, out_features=64),nn.Linear(in_features=64, out_features=10))def forward(self, x):x = self.model1(x)return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01 )
for epoch in range(100):running_loss = 0.0for data in dataloader:imgs,targets = dataoutputs = tudui(imgs)result_loss = loss(outputs, targets)#梯度置零optim.zero_grad()#反向傳播result_loss.backward()#更新參數optim.step()running_loss += result_lossprint(running_loss)

?

?對網絡模型的修改

import torchvision
from torch import nn# train_data = torchvision.datasets.ImageNet(root='./data_IMG',split="train", transform=torchvision.transforms.ToTensor())
#學習如何改變現有的網絡結構
vgg16_false = torchvision.models.vgg16(pretrained=False)vgg16_true = torchvision.models.vgg16(pretrained=True)train_data = torchvision.datasets.CIFAR10(root='./data_CIF',train=True,transform=torchvision.transforms.ToTensor(),download=True)
#加一個線性層
vgg16_true.add_module('add_linear',nn.Linear(in_features=1000,out_features=10))
vgg16_true.classifier.add_module('add_linear',nn.Linear(in_features=1000,out_features=10))
#修改一個線性層
vgg16_false.classifier[6] = nn.Linear(in_features=4096,out_features=10)
print(vgg16_false)

網絡模型的保存與讀取

#model_save.pyimport torch
import torchvision
from torch import nnvgg16 = torchvision.models.vgg16(pretrained=False)
#保存方式一:模型結構+模型參數
torch.save(vgg16,"vgg16.pth")#保存方式二:模型參數(官方推薦)
torch.save(vgg16.state_dict(),"vgg16_state_dict.pth")#陷阱
class Tudui(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3,64,kernel_size=3)def forward(self, x):x = self.conv1(x)return xtudui = Tudui()
torch.save(tudui,"tudui_method1.pth")
#model_load.pyimport torch
import torchvisionfrom torch import nn#保存方式一,加載模型
# model = torch.load("vgg16.pth",weights_only=False)
# print(model)#方式二,加載模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# model = torch.load("vgg16_state_dict.pth")
vgg16.load_state_dict(torch.load("vgg16_state_dict.pth"))
# print(vgg16)#陷阱
#陷阱
class Tudui(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3,64,kernel_size=3)def forward(self, x):x = self.conv1(x)return x#如果直接這么調用的話,機器會找不到類在哪里
# 當你 torch.save(model) 保存整個模型時,它會把整個類的信息序列化。如果加載時當前文件找不到 Tudui 類,自然就炸了。
#可以將定義寫到這個類來,也可以在開頭寫from model_save import *
#!!!更推薦一下模式:
"""
# 保存
torch.save(model.state_dict(), "tudui_method2.pth")# 加載
model = Tudui()
model.load_state_dict(torch.load("tudui_method2.pth"))優點:不管類在哪個文件,只要 Tudui() 存在就能加載;避免因為 class 變動導致報錯;更加靈活,適合后期修改網絡結構。
"""
model = torch.load("tudui_method1.pth",weights_only=False)
print(model)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/89656.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/89656.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/89656.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

PostgreSQL實戰:高效SQL技巧

PostgreSQL PG 在不同領域可能有不同的含義,以下是幾種常見的解釋: PostgreSQL PostgreSQL(簡稱 PG)是一種開源的關系型數據庫管理系統(RDBMS),支持 SQL 標準并提供了豐富的擴展功能。它廣泛應用于企業級應用、Web 服務和數據分析等領域。 PostgreSQL 的詳細介紹 Po…

3-大語言模型—理論基礎:生成式預訓練語言模型GPT(代碼“活起來”)

目錄 1、GPT的模型結構如圖所示 2、介紹GPT自監督預訓練、有監督下游任務微調及預訓練語言模型 2.1、GPT 自監督預訓練 2.1.1、 輸入編碼:詞向量與位置向量的融合 2.1.1.1、 輸入序列與詞表映射 2.1.1.2、 詞向量矩陣與查表操作 3. 位置向量矩陣 4. 詞向量與…

【Redis 】看門狗:分布式鎖的自動續期

在分布式系統的開發中,保證數據的一致性和避免并發沖突是至關重要的任務。Redis 作為一種廣泛使用的內存數據庫,提供了實現分布式鎖的有效手段。然而,傳統的 Redis 分布式鎖在設置了過期時間后,如果任務執行時間超過了鎖的有效期&…

MYSQL--快照讀和當前讀及并發 UPDATE 的鎖阻塞

快照讀和當前讀在 MySQL 中,數據讀取方式主要分為 快照讀 和 當前讀,二者的核心區別在于是否依賴 MVCC(多版本并發控制)的歷史版本、是否加鎖,以及讀取的數據版本是否為最新。以下是詳細說明:一、快照讀&am…

css樣式中的選擇器和盒子模型

目錄 一、行內樣式二、內部樣式三、外部樣式四、結合選擇器五、屬性選擇器六、包含選擇器七、子選擇器八、兄弟選擇器九、選擇器組合十、偽元素選擇器十一、偽類選擇器十二、盒子模型 相關文章 學習標簽、屬性、選擇器和外部加樣式積累CSS樣式屬性:padding、marg…

關于基于lvgl庫做的注冊登錄功能的代碼步驟:

以下是完整的文件拆分和代碼存放說明,按功能模塊化劃分,方便工程管理:一、需要創建的文件清單 文件名 作用 類型 main.c 程序入口,初始化硬件和LVGL 源文件 ui.h 聲明界面相關函數 頭文件 ui.c 實現登錄、注冊、主頁面的UI 源文…

RAII機制以及在ROS的NodeHandler中的實現

好的,這是一個非常核心且優秀的設計問題。我們來分兩步詳細解析:先徹底搞懂什么是 RAII,然后再看 ros::NodeHandle 是如何巧妙地運用這一機制的。1. 什么是 RAII 機制? RAII 是 “Resource Acquisition Is Initialization” 的縮寫…

Linux LVS集群技術

LVS集群概述1、集群概念1.1、介紹集群是指多臺服務器集中在一起,實現同一業務,可以視為一臺計算機。多臺服務器組成的一組計算機,作為一個整體存在,向用戶提供一組網絡資源,這些單個的服務器就是集群的節點。特點&…

spring-ai-alibaba如何上傳文件并解析

問題引出 在我們日常使用大模型時,有一類典型的應用場景,就是將文件發送給大模型,然后由大模型進行解析,提煉總結等,這一類功能在官方app中較為常見,但是在很多模型的api中都不支持,那如何使用…

「雙容器嵌套布局法」:打造清晰層級的網頁架構設計

一、命名與核心概念 “雙容器嵌套布局法”,核心是通過兩層容器嵌套構建網頁結構:外層容器負責控制布局的“宏觀約束”(如頁面最大寬度、背景色等),內層容器聚焦“微觀排版”(內容居中、內邊距調整、紅色內容…

基于深度學習的自然語言處理:構建情感分析模型

前言 自然語言處理(NLP)是人工智能領域中一個非常活躍的研究方向,它致力于使計算機能夠理解和生成人類語言。情感分析(Sentiment Analysis)是NLP中的一個重要應用,其目標是從文本中識別和提取情感傾向&…

JWT原理及利用手法

JWT 原理 JSON Web Token (JWT) 是一種開放的行業標準,用于在系統之間以 JSON 對象的形式安全地傳輸信息。這些信息經過數字簽名,因此可以被驗證和信任。其常用于身份驗證、會話管理和訪問控制機制中傳遞用戶信息。 與傳統的會話令牌相比,JWT…

DeepSeek 助力 Vue3 開發:打造絲滑的日歷(Calendar),日歷_睡眠記錄日歷示例(CalendarView01_30)

前言:哈嘍,大家好,今天給大家分享一篇文章!并提供具體代碼幫助大家深入理解,徹底掌握!創作不易,如果能幫助到大家或者給大家一些靈感和啟發,歡迎收藏關注哦 💕 目錄DeepS…

git的diff命令、Config和.gitignore文件

diff命令:比較git diff xxx:工作目錄 vs 暫存區(比較現在修改之后的工作區和暫存區的內容)git diff --cached xxx:暫存區 vs Git倉庫(現在暫存區內容和最一開始提交的文件內容的比較)git diff H…

Linux中的LVS集群技術

一、實驗環境(RHEL 9)1、NAT模式的實驗環境主機名IP地址網關網絡適配器功能角色client172.25.254.111/24(NAT模式的接口)172.25.254.2NAT模式客戶機lvs172.25.254.100/24(NAT模式的接口)192.168.0.100/24&a…

【數據結構】「隊列」(順序隊列、鏈式隊列、雙端隊列)

- 第 112篇 - Date: 2025 - 07 - 20 Author: 鄭龍浩(仟墨) 文章目錄隊列(Queue)1 基本介紹1.1 定義1.2 棧 與 隊列的區別1.3 重要術語2 基本操作3 順序隊列(循環版本)兩種版本兩種版本區別版本1.1 - rear指向隊尾后邊 且 無 size …

Java行為型模式---解釋器模式

解釋器模式基礎概念解釋器模式(Interpreter Pattern)是一種行為型設計模式,其核心思想是定義一個語言的文法表示,并定義一個解釋器,使用該解釋器來解釋語言中的句子。這種模式將語法解釋的責任分開,使得語法…

[spring6: PointcutAdvisor MethodInterceptor]-簡單介紹

Advice Advice 是 AOP 聯盟中所有增強(通知)類型的標記接口,表示可以被織入目標對象的橫切邏輯,例如前置通知、后置通知、異常通知、攔截器等。 package org.aopalliance.aop;public interface Advice {}BeforeAdvice 前置通知的標…

地圖定位與導航

定位 1.先申請地址權限(大致位置精準位置) module.json5文件 "requestPermissions": [{"name": "ohos.permission.INTERNET" },{"name": "ohos.permission.LOCATION","reason": "$string:app_name",&qu…

【數據結構】揭秘二叉樹與堆--用C語言實現堆

文章目錄1.樹1.1.樹的概念1.2.樹的結構1.3.樹的相關術語2.二叉樹2.1.二叉樹的概念2.2.特殊的二叉樹2.2.1.滿二叉樹2.2.2.完全二叉樹2.3.二叉樹的特性2.4.二叉樹的存儲結構2.4.1.順序結構2.4.2.鏈式結構3.堆3.1.堆的概念3.2.堆的實現3.2.1.堆結構的定義3.2.2.堆的初始化3.2.3.堆…