【深度學習基礎 2】 PyTorch 框架

目錄

一、 PyTorch 簡介

二、安裝 PyTorch

三、PyTorch 常用函數和操作

3.1 創建張量(Tensor)

3.2 基本數學運算

3.3 自動求導(Autograd)

3.4 定義神經網絡模型

3.5 訓練與評估模型

3.6 使用模型進行預測

四、注意事項

五、完整訓練示例代碼


一、 PyTorch 簡介

????????PyTorch 是由 Facebook 開發的開源深度學習框架,以動態計算圖(Dynamic Computational Graph)著稱,允許在運行時即時定義和修改模型結構,便于調試和研究。它支持 GPU 加速,并擁有豐富的生態系統,適用于自然語言處理、計算機視覺等眾多領域。

主要特點:

  • 動態計算圖:每次運行時構建計算圖,便于調試和靈活性高。

  • 自動求導(Autograd):支持自動求導,便于梯度計算與反向傳播。

  • 模塊化設計:通過 torch.nn 提供豐富的神經網絡層及模塊,方便構建復雜模型。

  • 豐富的生態:支持 torchvision、torchtext 等擴展庫,加速模型開發和實驗。

二、安裝 PyTorch

????????可參考YOLO系列環境配置及訓練_yolo環境配置-CSDN博客?中pytorch的安裝方法,以下簡要概括:(以安裝CPU版本為例)

pip install torch torchvision

安裝后,可以通過以下代碼驗證安裝及查看版本:

import torch
print("PyTorch 版本:", torch.__version__)

?CPU版本安裝成功的輸出示例為:

PyTorch 版本: 1.13.0

三、PyTorch 常用函數和操作

3.1 創建張量(Tensor)

????????與TensorFlow一樣,在 PyTorch 中的張量類似于 NumPy 的數組,同時支持 GPU 加速。

例如:

import torch# 創建標量、向量和矩陣
scalar = torch.tensor(5)
vector = torch.tensor([1, 2, 3])
matrix = torch.tensor([[1, 2], [3, 4]])print("標量:", scalar)
print("向量:", vector)
print("矩陣:\n", matrix)

樣例輸出:

標量: tensor(5)
向量: tensor([1, 2, 3])
矩陣:tensor([[1, 2],[3, 4]])

3.2 基本數學運算

????????PyTorch 同樣提供了基本的數學運算,例如:

a = torch.tensor(3.0)
b = torch.tensor(2.0)print("加法:", torch.add(a, b))
print("乘法:", torch.mul(a, b))
# 矩陣乘法
mat1 = torch.tensor([[1, 2]])
mat2 = torch.tensor([[3], [4]])
print("矩陣乘法:\n", torch.matmul(mat1, mat2))

樣例輸出:

加法: tensor(5.)
乘法: tensor(6.)
矩陣乘法:tensor([[11]])

3.3 自動求導(Autograd)

????????PyTorch 的 autograd 功能可以自動計算梯度,非常適合神經網絡反向傳播的實現。

例如:

# 定義一個需要計算梯度的張量
x = torch.tensor(2.0, requires_grad=True)# 定義函數 y = x3 + 2x + 1
y = x**3 + 2*x + 1# 反向傳播,計算 dy/dx
y.backward()print("dy/dx:", x.grad)

樣例輸出:

dy/dx: tensor(14.)

3.4 定義神經網絡模型

????????PyTorch 提供了 torch.nn 模塊來構建神經網絡模型。例如下面使用一個簡單的全連接層構建模型:

import torch.nn as nn
import torch.optim as optim# 定義一個簡單的神經網絡
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(784, 10)  # 輸入 784 維,輸出 10 維(例如手寫數字分類)def forward(self, x):out = self.fc(x)return out# 實例化模型并打印模型結構
model = SimpleNet()
print(model)

樣例輸出:

SimpleNet((fc): Linear(in_features=784, out_features=10, bias=True)
)

3.5 訓練與評估模型

在3.4的基礎上,我們繼續完善構建。

例如:

# 假設我們有一個批次的輸入數據(如手寫數字圖像,已展平為784維向量)
batch_size = 32
dummy_input = torch.randn(batch_size, 784)  # 隨機生成一批數據
dummy_labels = torch.randint(0, 10, (batch_size,))  # 隨機生成對應的標簽# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 前向傳播
outputs = model(dummy_input)
loss = criterion(outputs, dummy_labels)print("初始損失:", loss.item())# 反向傳播和優化
optimizer.zero_grad()
loss.backward()
optimizer.step()# 再次輸出損失(注意:僅作為示例,損失值可能不會明顯下降)
outputs_after = model(dummy_input)
loss_after = criterion(outputs_after, dummy_labels)
print("更新后損失:", loss_after.item())

樣例輸出:

初始損失: 2.280543327331543
更新后損失: 2.2781271934509277

注意:?損失值會受到隨機數據和權重初始化的影響,實際訓練中損失下降情況應更為明顯。

3.6 使用模型進行預測

在 3.5 訓練結束后,我們可以通過調用模型的 forward 方法,可以對新的數據進行預測:

# 對一條測試數據進行預測
test_sample = torch.randn(1, 784)
pred_logits = model(test_sample)
pred_label = torch.argmax(pred_logits, dim=1)
print("預測類別:", pred_label.item())

樣例輸出:

預測類別: 7

四、注意事項

????????在訓練過程中,切忌混淆設備(Device),注意將模型和數據遷移到同一設備:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
inputs = inputs.to(device)

五、完整訓練示例代碼

import torch
import torch.nn as nn
import torch.optim as optim# 定義一個簡單的神經網絡
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(784, 10)  # 輸入 784 維,輸出 10 維(例如手寫數字分類)def forward(self, x):out = self.fc(x)return out# 實例化模型
model = SimpleNet()
print(model)# 設置設備(如果有 GPU 就用 GPU,否則用 CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)# 假設我們有一個批次的輸入數據(如手寫數字圖像,已展平為784維向量)
batch_size = 32
dummy_input = torch.randn(batch_size, 784).to(device)  # 隨機生成一批數據并遷移到 device
dummy_labels = torch.randint(0, 10, (batch_size,)).to(device)  # 隨機生成對應的標簽并遷移到 device# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 前向傳播
outputs = model(dummy_input)
loss = criterion(outputs, dummy_labels)
print("初始損失:", loss.item())# 反向傳播和優化
optimizer.zero_grad()
loss.backward()
optimizer.step()# 再次輸出損失(僅作為示例,損失值可能不會明顯下降)
outputs_after = model(dummy_input)
loss_after = criterion(outputs_after, dummy_labels)
print("更新后損失:", loss_after.item())# 對一條測試數據進行預測
test_sample = torch.randn(1, 784).to(device)
pred_logits = model(test_sample)
pred_label = torch.argmax(pred_logits, dim=1)
print("預測類別:", pred_label.item())

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/74612.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/74612.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/74612.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

uniapp中APP上傳文件

uniapp提供了uni.chooseImage(選擇圖片), uni.chooseVideo(選擇視頻)這兩個api,但是對于打包成APP的話就沒有上傳文件的api了。因此我采用了plus.android中的方式來打開手機的文件管理從而上傳文件。 下面…

推陳換新系列————java8新特性(編程語言的文藝復興)

文章目錄 前言一、新特性秘籍二、Lambda表達式2.1 語法2.2 函數式接口2.3 內置函數式接口2.4 方法引用和構造器引用 三、Stream API3.1 基本概念3.2 實戰3.3 優勢 四、新的日期時間API4.1 核心概念與設計原則4.2 核心類詳解4.2.1 LocalDate(本地日期)4.2…

樹莓派5從零開發至脫機腳本運行教程——1.系統部署篇

樹莓派5應用實例——工創視覺 前言 哈嘍,各位小伙伴,大家好。最近接觸了樹莓派,然后簡單的應用了一下,學習程度并不是很深,不過足夠剛入手樹莓派5的小伙伴們了解了解。后面的幾篇更新的文章都是關于開發樹莓派5的內容…

GPT Researcher 的win docker安裝攻略

github網址是:https://github.com/assafelovic/gpt-researcher 因為docker安裝方法不夠清晰,因此寫一個使用方法 以下是針對 Windows 系統 使用 Docker 運行 AI-Researcher 項目的 詳細分步指南: 步驟 1:安裝 Docker 下載 Docke…

【后端】【Django DRF】從零實現RBAC 權限管理系統

Django DRF 實現 RBAC 權限管理系統 在 Web 應用中,權限管理 是一個核心功能,尤其是在多用戶系統中,需要精細化控制不同用戶的訪問權限。本文介紹如何使用 Django DRF 設計并實現 RBAC(基于角色的訪問控制)系統&…

C#基礎學習(五)函數中的ref和out

1. 引言:為什么需要ref和out? ?問題背景:函數參數默認按值傳遞,值類型在函數內修改不影響外部變量;引用類型重新賦值時外部對象不變。?核心作用:允許函數內部修改外部變量的值,實現“雙向傳參…

八綱辨證總則

一、八綱辨證的核心定義 八綱即陰、陽、表、里、寒、熱、虛、實,是中醫分析疾病共性的綱領性辨證方法。 作用:通過八類證候歸納疾病本質,為所有辨證方法(如臟腑辨證、六經辨證)的基礎。 二、八綱分類與對應關系 1. 總…

【linux重設gitee賬號密碼 克隆私有倉庫報錯】

出現問題時 Cloning into xxx... remote: [session-1f4b16a4] Unauthorized fatal: Authentication failed for https://gitee.com/xxx/xxx.git/解決方案 先打開~/.git-credentials vim ~/.git-credentials或者創建一個 torch ~/.git-credentials 添加授權信息 username/pa…

綠聯NAS安裝內網穿透實現無公網IP也能用手機平板遠程訪問經驗分享

文章目錄 前言1. 開啟ssh服務2. ssh連接3. 安裝cpolar內網穿透4. 配置綠聯NAS公網地址 前言 大家好,今天給大家帶來一個超級炫酷的技能——如何在綠聯NAS上快速安裝cpolar內網穿透工具。想象一下,即使沒有公網IP,你也能隨時隨地遠程訪問自己…

CSS 美化頁面(一)

一、CSS概念 CSS(Cascading Style Sheets,層疊樣式表)是一種用于描述 HTML 或 XML(如 SVG、XHTML)文檔 樣式 的樣式表語言。它控制網頁的 外觀和布局,包括字體、顏色、間距、背景、動畫等視覺效果。 二、CS…

空轉 | GetAssayData doesn‘t work for multiple layers in v5 assay.

問題分析 當我分析多個樣本的時候,而我的seurat又是v5時,通常就會出現這樣的報錯。 錯誤的原因有兩個: 一個是參數名有slot變成layer 一個是GetAssayData 不是自動合并多個layers,而是選擇保留。 那么如果我們想合并多個樣本&…

UE4學習筆記 FPS游戲制作17 讓機器人持槍 銷毀機器人時也銷毀機器人的槍 讓機器人射擊

添加武器插槽 打開機器人的Idle動畫,方便查看武器位置 在動畫面板里打開骨骼樹,找到右手的武器節點,右鍵添加一個插槽,重命名為RightWeapon,右鍵插槽,添加一個預覽資產,選擇Rifle,根…

【JavaScript】七、函數

文章目錄 1、函數的聲明與調用2、形參默認值3、函數的返回值4、變量的作用域5、變量的訪問原則6、匿名函數6.1 函數表達式6.2 立即執行函數 7、練習8、邏輯中斷9、轉為布爾型 1、函數的聲明與調用 function 函數名(形參列表) {函數體 }eg: // 聲明 function sayHi…

硬件基礎--05_電壓

電壓(電勢差) 有了電壓,電子才能持續且定向移動起來,所有電壓是形成電流的必要條件。 電壓越大,能“定向移動”起來的電子就越多,電流就會越大。 有電壓的同時,形成閉合回路才會有電流,不是有電壓就有電流…

ES數據過多,索引拆分

公司企微聊天數據存儲在 ES 中,雖然按照企業分儲在不同的ES 索引中,但某些常用的企微主體使用量還是很大。4年中一個索引存儲數據已經達到46多億條數據,占用存儲3.1tb, ES 配置 由于多一個副本,存儲得翻倍,成本考慮…

存儲服務器是指什么

今天小編主要來為大家介紹存儲服務器主要是指什么,存儲服務器與傳統的物理服務器和云服務器是不同的,其是為了特定的目標所設計的,在硬件配置方式上也有著一定的區別,存儲空間會根據需求的不同而改變。 存儲服務器中一般會配備大容…

golang不使用鎖的情況下,對slice執行并發寫操作,是否會有并發問題呢?

背景 并發問題最簡單的解決方案加個鎖,但是,加鎖就會有資源爭用,提高并發能力其中的一個優化方向就是減少鎖的使用。 我在之前的這篇文章《開啟多個協程,并行對struct中的每個元素操作,是否會引起并發問題?》中討論過多協程場景下struct的并發問題。 Go語言中的slice在…

Java知識整理round1

一、常見集合篇 1. 為什么數組索引從0開始呢?假如從1開始不行咩 數組(Array):一種用連續的內存空間存儲相同數據類型數據的線性數據結構 (1)在根據數組索引獲取元素的時候,會用索引和尋址公式…

【C++指針】搭建起程序與內存深度交互的橋梁(下)

🔥🔥 個人主頁 點擊🔥🔥 每文一詩 💪🏼 往者不可諫,來者猶可追——《論語微子篇》 譯文:過去的事情已經無法挽回,未來的歲月還可以迎頭趕上。 目錄 C內存模型 new與…

JavaScript創建對象的多種方式

在JavaScript中,創建對象有多種方式,每種方式都有其優缺點。本文將介紹四種常見的對象創建模式:工廠模式、構造函數模式、原型模式和組合模式,并分析它們的特點以及如何優化。 1. 工廠模式 工廠模式是一種簡單的對象創建方式&am…