使?Pytorch構建?個神經?絡

關于torch.nn:

使?Pytorch來構建神經?絡, 主要的?具都在torch.nn包中.
nn依賴于autograd來定義模型, 并對其?動求導.

構建神經?絡的典型流程:

  1. 定義?個擁有可學習參數的神經?絡
  2. 遍歷訓練數據集
  3. 處理輸?數據使其流經神經?絡
  4. 計算損失值
  5. 將?絡參數的梯度進?反向傳播
  6. 以?定的規則更新?絡的權重
我們?先定義?個Pytorch實現的神經?絡:
# 導?若??具包
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定義?個簡單的?絡類
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 定義第?層卷積神經?絡, 輸?通道維度=1, 輸出通道維度=6, 卷積核??3*3self.conv1 = nn.Conv2d(1, 6, 3)# 定義第?層卷積神經?絡, 輸?通道維度=6, 輸出通道維度=16, 卷積核??3*3self.conv2 = nn.Conv2d(6, 16, 3)# 定義三層全連接?絡self.fc1 = nn.Linear(16 * 6 * 6, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x): # 在(2, 2)的池化窗?下執?最?池化操作x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):# 計算size, 除了第0個維度上的batch_sizesize = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_features
net = Net()
print(net)

輸出結果:
Net((conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))(fc1): Linear(in_features=576, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)
注意:
模型中所有的可訓練參數, 可以通過net.parameters()來獲得.
params = list(net.parameters())
print(len(params))
print(params[0].size())

輸出結果:
10
torch.Size([6, 1, 3, 3])

輸出結果:
假設圖像的輸?尺?為32 * 32:
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

輸出結果:
tensor([[ 0.1242, 0.1194, -0.0584, -0.1140, 0.0661, 0.0191, -0.0966, 
0.0480, 0.0775, -0.0451]], grad_fn=<AddmmBackward>)

有了輸出張量后, 就可以執?梯度歸零和反向傳播的操作了.
net.zero_grad()
out.backward(torch.randn(1, 10))

注意:
  • torch.nn構建的神經?絡只?持mini-batches的輸?, 不?持單?樣本的輸?.
  • ?如: nn.Conv2d 需要?個4D Tensor, 形狀為(nSamples, nChannels, Height, Width). 如果你的輸?只有單?樣本形式, 則需要執?input.unsqueeze(0), 主動將3D Tensor擴充成4D Tensor.
損失函數
  • 損失函數的輸?是?個輸?的pair: (output, target), 然后計算出?個數值來評估output和target之間的差距??.
  • 在torch.nn中有若?不同的損失函數可供使?, ?如nn.MSELoss就是通過計算均?差損失來評估輸?和?標值之間的差距.
應?nn.MSELoss計算損失的?個例?:
output = net(input)
target = torch.randn(10)
# 改變target的形狀為?維張量, 為了和output匹配
target = target.view(1, -1)
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)

輸出結果:
tensor(1.1562, grad_fn=<MseLossBackward>)

關于?向傳播的鏈條: 如果我們跟蹤loss反向傳播的?向, 使?.grad_fn屬性打印, 將可以看到?張完整的計算圖如下:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear-> MSELoss-> loss

當調?loss.backward()時, 整張計算圖將對loss進??動求導, 所有屬性requires_grad=True的Tensors都將參與梯度求導的運算, 并將梯度累加到Tensors中的.grad屬性中.
print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU

輸出結果:

反向傳播(backpropagation)

在Pytorch中執?反向傳播?常簡便, 全部的操作就是loss.backward().
在執?反向傳播之前, 要先將梯度清零, 否則梯度會在不同的批次數據之間被累加.
執??個反向傳播的?例?:
# Pytorch中執?梯度清零的代碼
net.zero_grad()
print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)
# Pytorch中執?反向傳播的代碼
loss.backward()
print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

輸出結果:
conv1.bias.grad before backward
tensor([0., 0., 0., 0., 0., 0.])
conv1.bias.grad after backward
tensor([-0.0002, 0.0045, 0.0017, -0.0099, 0.0092, -0.0044])

更新?絡參數

  • 更新參數最簡單的算法就是SGD(隨機梯度下降).
  • 具體的算法公式表達式為: weight = weight - learning_rate * gradient
?先?傳統的Python代碼來實現SGD如下:
learning_rate = 0.01
for f in net.parameters():f.data.sub_(f.grad.data * learning_rate)

然后使?Pytorch官?推薦的標準代碼如下:
# ?先導?優化器的包, optim中包含若?常?的優化算法, ?如SGD, Adam等
import torch.optim as optim
# 通過optim創建優化器對象
optimizer = optim.SGD(net.parameters(), lr=0.01)
# 將優化器執?梯度清零的操作
optimizer.zero_grad()
output = net(input)
loss = criterion(output, target)
# 對損失值執?反向傳播的操作
loss.backward()
# 參數的更新通過??標準代碼來執?
optimizer.step()

?節總結

  • 學習了構建?個神經?絡的典型流程:
    • 定義?個擁有可學習參數的神經?絡
    • 遍歷訓練數據集
    • 處理輸?數據使其流經神經?絡
    • 計算損失值
    • 將?絡參數的梯度進?反向傳播
    • 以?定的規則更新?絡的權重
  • 學習了損失函數的定義:
  • 采?torch.nn.MSELoss()計算均?誤差.
  • 通過loss.backward()進?反向傳播計算時, 整張計算圖將對loss進??動求導, 所有屬性
  • requires_grad=True的Tensors都將參與梯度求導的運算, 并將梯度累加到Tensors中
  • 的.grad屬性中.
  • 學習了反向傳播的計算?法:
  • 在Pytorch中執?反向傳播?常簡便, 全部的操作就是loss.backward().
  • 在執?反向傳播之前, 要先將梯度清零, 否則梯度會在不同的批次數據之間被累加.
    • net.zero_grad()
    • loss.backward()
  • 學習了參數的更新?法:
  • 定義優化器來執?參數的優化與更新.
    • optimizer = optim.SGD(net.parameters(), lr=0.01)
  • 通過優化器來執?具體的參數更新.
    • optimizer.step()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/89380.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/89380.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/89380.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

網絡爬蟲的詳細知識點

基本介紹 什么是網絡爬蟲 網絡爬蟲&#xff08;Web Crawler&#xff09;是一種自動化程序&#xff0c;用于從互聯網上抓取、解析和存儲網頁數據。其核心功能是模擬人類瀏覽行為&#xff0c;通過HTTP/HTTPS協議訪問目標網站&#xff0c;提取文本、鏈接、圖片或其他結構化信息&…

AndroidX中ComponentActivity與原生 Activity 的區別

一、AndroidX 與原生 Activity 的區別 1. 概念與背景 原生 Activity&#xff1a;指 Android 早期&#xff08;API 1 起&#xff09;就存在于 android.app 包下的 Activity 類&#xff08;如 android.app.Activity&#xff09;&#xff0c;是 Android 最初的 Activity 實現&…

Spring AI 使用 Elasticsearch 作為向量數據庫

前言 嗨&#xff0c;大家好&#xff0c;我是雪荷&#xff0c;最近在公司開發 AI 知識庫&#xff0c;同時學到了一些 AI 開發相關的技術&#xff0c;這期先與大家分享一下如何用 ES 當做向量數據庫。 安裝ES 第一步我們先安裝 Elasticsearch&#xff0c;這里建議 Elasticsear…

TypeScript 配置全解析:tsconfig.json、tsconfig.app.json 與 tsconfig.node.json 的深度指南

前言在現代前端和后端開發中&#xff0c;TypeScript 已經成為許多開發者的首選語言。然而&#xff0c;TypeScript 的配置文件&#xff08;特別是多個配置文件協同工作時&#xff09;常常讓開發者感到困惑。本文將深入探討 tsconfig.json、tsconfig.app.json 和 tsconfig.node.j…

讀書筆記(學會說話)

1、一個人只有會說話&#xff0c;才會有好人緣&#xff0c;做事才會順利。會說話的人容易成功。善于說話的人易成功&#xff0c;而不善說話的人往往寸步難行。我們要把話說得好聽&#xff0c;同時更要把事做得漂亮。或許一句話&#xff0c;一件事&#xff0c;就可能使人生的旅途…

私有服務器AI智能體搭建-大模型選擇優缺點、擴展性、可開發

以下是主流 AI 框架與模型的對比分析&#xff0c;涵蓋其優缺點、擴展性、可開發性等方面。 文章目錄一、AI 框架對比二、主流大模型對比三、擴展性對比總結四、可開發性對比總結五、選擇建議&#xff08;按場景&#xff09;六、未來趨勢一、AI 框架對比 框架優點缺點擴展性可開…

OpenCV直線段檢測算法類cv::line_descriptor::LSDDetector

操作系統&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 編程語言&#xff1a;C11 算法描述 該類用于實現 LSD (Line Segment Detector) 直線段檢測算法。LSD 是一種快速、準確的直線檢測方法&#xff0c;能夠在不依賴邊緣檢測的前提下直接從…

Go語言流程控制(if / for)

分支結構package mainimport ("fmt""strconv" )/* 1.順序結構 2.分支結構 3.循環結構 *//* if 條件1 {// 條件1為真時執行的代碼 } else if 條件2 {// 條件1為假但條件2為真時執行的代碼 } else {// 所有條件均為假時執行的代碼 }一種特殊的條件分支結構if…

wx小程序設置沉浸式導航文字高度問題

第一步&#xff1a;在app.json中設置"navigationStyle": "custom"第二步驟&#xff1a;文件的home.js中// pages/test/test.js Page({/*** 頁面的初始數據*/data: {statusBarHeight: 0,navBarHeight: 44 // 自定義導航內容區高度(單位px)},/*** 生命周期函…

C++算法競賽篇:DevC++ 如何進行debug調試

C算法競賽篇&#xff1a;DevC 如何進行debug調試前言一、準備工作&#xff1a;編譯生成可執行程序二、核心步驟&#xff1a;設置斷點與啟動調試1. 設置斷點2. 啟動調試模式三、調試操作&#xff1a;逐步執行與變量監控1. 逐步執行代碼2. 監控變量值變化四、調試結束前言 在算法…

語音大模型速覽(三)- cosyvoice2

CosyVoice 2: Scalable Streaming Speech Synthesis with Large Language Models 論文鏈接&#xff1a;https://arxiv.org/pdf/2412.10117代碼鏈接&#xff1a;https://github.com/FunAudioLLM/CosyVoice 一句話總結 CosyVoice 2 是一款改進的流式語音合成模型&#xff0c;其…

-lstdc++與-static-libstdc++的用法和差異

CMakeLists.txt 里寫了&#xff1a; target_link_libraries(${PROJECT_NAME} PRIVATEgccstdc ) target_link_options(${PROJECT_NAME} PRIVATE -static-libstdc)看起來像是“鏈接了兩次 C 標準庫”&#xff0c;其實它們的作用完全不同&#xff1a;1. target_link_libraries(...…

Redis學習其二(事務,SpringBoot整合,持久化RDB和AOF)

文章目錄5,事務5.1Redis 事務不保證原子性的原因5.2事務操作過程5.3監控6,SpringBoot整合Redis6.1Redis客戶端6.1.1Jedis簡單使用6.1.2Lettuce&Jedis6.2配置相關6.3使用6.3.1使用RedisTemplate6.3.2Redis工具類7,持久化RDB7.1RDB持久化原理7.2觸發機制save命令flushall命令…

springboot項目部署到K8S

java后臺 創建harbor鏡像拉取Secret&#xff1a;kubectl create secret docker-registry harbor-regcred \--docker-server \ #harbor倉庫地址--docker-username \ #harbor 賬號--docker-password \ #harbor密碼-n productionDockerfile FROM *harbor地址*/library/custom-jdk…

【FPGA開發】一文輕松入門Modelsim的基本操作

Modelsim仿真的步驟 &#xff08;1&#xff09;創建新的工程。 &#xff08;2&#xff09;在彈出的窗口中&#xff0c;確定項目名和工作路徑&#xff0c;庫保持為work不變(如有需要可以根據需求進行更改)。 &#xff08;3&#xff09;添加已經存在的文件&#xff08;rtl代碼和t…

服務攻防-Java組件安全FastJson高版本JNDI不出網C3P0編碼繞WAF寫入文件CI鏈

服務攻防-Java組件安全&FastJson&高版本JNDI&不出網C3P0&編碼繞WAF&寫入文件CI鏈26天 原創 朝陽 Sec朝陽 2025年07月18日 09:23 湖北 標題已修改 演示環境&#xff1a; https://github.com/lemono0/FastJsonParty FastJson全版本Docker漏洞環境(涵蓋1.…

【Python】DRF核心組件詳解:Mixin與Generic視圖

在 Django REST Framework (DRF) 中&#xff0c;mixins.CreateModelMixin、mixins.ListModelMixin、GenericAPIView 和 GenericViewSet 是構建 API 視圖的核心組件。以下是對這些組件的主要方法及其職責的簡要說明&#xff0c;內容清晰且結構化&#xff1a;1. mixins.CreateMod…

HTML+CSS+JS基礎

文章目錄&#xff08;一&#xff09;html1.常見標簽&#xff08;1&#xff09;注釋&#xff08;2&#xff09;標題 h1~h6&#xff08;3&#xff09;段落 p&#xff08;4&#xff09;換行與空格 br \ &#xff08;5&#xff09;格式化標簽 b i s u&#xff08;6&#xff09;…

Vue導出Html為Word中包含圖片在Microsoft Word顯示異常問題

問題背景 碰到一個問題&#xff1a;將包含圖片和SVG數學公式的HTML內容導出為Word文檔時&#xff0c;將圖片都轉為ase64格式導出&#xff0c;在WPS Word中顯示正常&#xff0c;但是在Microsoft Word中出現圖片示異常。具體問題表現 WPS兼容性&#xff1a;在WPS中顯示正常&#…

橢圓曲線密碼學 Elliptic Curve Cryptography

密碼學是研究在存在對抗行為的情況下還能安全通信的技術。即算法加密信息&#xff0c;再算法解密出信息。加密分為兩類 1. Symmetric-key Encryption (secret key encryption) 即一種密鑰&#xff0c;加密和解密使用同一密鑰&#xff0c;可相互轉換 2. Asymmetric-key Encry…