PyTorch 學習筆記

環境:python3.8 + PyTorch2.4.1+cpu + PyCharm

參考鏈接:

快速入門 — PyTorch 教程 2.6.0+cu124 文檔

PyTorch 文檔 — PyTorch 2.4 文檔

快速入門

導入庫

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

加載數據集

使用 FashionMNIST 數據集。每個 TorchVision 都包含兩個參數: 分別是 修改樣本 和 標簽。

# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",     # 數據集存儲的位置train=True,      # 加載訓練集(True則加載訓練集)download=True,   # 如果數據集在指定目錄中不存在,則下載(True才會下載)transform=ToTensor(), # 應用于圖像的轉換列表,例如轉換為張量和歸一化
)# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,     # 加載測試集(False則加載測試集)download=True,transform=ToTensor(),
)

創建數據加載器

batch_size = 64# Create data loaders.
# DataLoader():batch_size每個批次的大小,shuffle=True則打亂數據
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader: # 遍歷訓練數據加載器,x相當于圖片,y相當于標簽print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

?

創建模型

為了在 PyTorch 中定義神經網絡,我們創建一個繼承 來自?nn.模塊。我們定義網絡的各層 ,并在函數中指定數據如何通過網絡。要加速 作,我們將其移動到 CUDA、MPS、MTIA 或 XPU 等加速器。如果當前加速器可用,我們將使用它。否則,我們使用 CPU。__init__forward

#使用加速器,并打印當前使用的加速器(當前加速器可用則使用當前的,否則使用cpu)
# device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu" # torch2.4.2并沒有accelerator這個屬性,2.6的才有,所以注釋掉
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")# 檢查 CUDA 是否可用
print("CUDA available:", torch.cuda.is_available())# Define model
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device) #torch2.4.2并沒有accelerator這個屬性,2.6的才有,所以注釋掉不用
# model = NeuralNetwork()
print(model)

優化模型參數

要訓練模型,我們需要一個損失函數和一個優化器:

loss_fn = nn.CrossEntropyLoss() # 損失函數,nn.CrossEntropyLoss()用于多分類
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) # 優化器,用于更新模型的參數,以最小化損失函數
'''
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
優化器用PyTorch 提供的隨機梯度下降(Stochastic Gradient Descent, SGD)優化器
model.parameters():將模型的參數傳遞給優化器,優化器會根據這些參數計算梯度并更新它們
lr=1e-3:學習率(learning rate),控制每次參數更新的步長
(較大的學習率可能導致訓練不穩定,較小的學習率可能導致訓練速度變慢)
'''

在單個訓練循環中,模型對訓練集進行預測(分批提供給它),并且 反向傳播預測誤差以調整模型的參數:

'''
訓練模型(單個epoch)
dataloader:數據加載器,用于按批次加載訓練數據
model     :神經網絡模型
loss_fn   :損失函數,用于計算預測值與真實值之間的誤差
optimizer :優化器,用于更新模型參數
'''
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train() # 將模型設置為訓練模式(啟用 dropout 和 batch normalization 的訓練行為)for batch, (X, y) in enumerate(dataloader): # 遍歷 dataloader 中的每個批次,獲取輸入 X 和標簽 yX, y = X.to(device), y.to(device) # 將數據移動到指定設備(如 GPU 或 CPU)# Compute prediction error# 計算預測損失,同時也是前向傳播pred = model(X)         # 模型的預測值,即模型的輸出loss = loss_fn(pred, y) # 計算損失:y為實際的類別標簽# Backpropagation 反向傳播和優化# 梯度清零應在每次反向傳播之前執行,以避免梯度累積(先用optimizer.zero_grad())loss.backward()       # 計算梯度optimizer.step()      # 使用優化器更新模型參數optimizer.zero_grad() # 清除之前的梯度(清零梯度,為下一輪計算做準備)# 梯度清零應在每次反向傳播之前執行,以避免梯度累積(在計算模型預測值前先用optimizer.zero_grad())if batch % 100 == 0: # 每 100 個批次打印一次損失值和當前處理的樣本數量loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

?進度條顯示

  • 如果數據集較大,訓練過程可能較慢。可以使用?tqdm?庫添加進度條,提升用戶體驗。例如:
from tqdm import tqdm
for batch, (X, y) in enumerate(tqdm(dataloader, desc="Training")):...

我們還根據測試集檢查模型的性能,以確保它正在學習:

# 測試模型
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 測試集的總樣本數num_batches = len(dataloader)   # 測試數據加載器(dataloader)的總批次數model.eval()                    # 設置為評估模式,這會關閉 dropout 和 batch normalization 的訓練行為test_loss, correct = 0, 0       # 累積測試損失和正確預測的樣本數with torch.no_grad(): # 禁用梯度計算,使用 torch.no_grad() 上下文管理器,避免計算梯度,從而節省內存并加速計算for X, y in dataloader:X, y = X.to(device), y.to(device) # 將數據加載到指定設備pred = model(X) # 模型預測test_loss += loss_fn(pred, y).item() # 累積損失correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 累積正確預測數# correct += (pred.argmax(1) == y).float().sum().item()  # 可以直接使用 .float(),更簡潔test_loss /= num_batches    # 平均損失correct /= size             # 準確率print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 累積正確預測數
# correct += (pred.argmax(1) == y).float().sum().item()  # 可以直接使用 .float(),更簡潔
'''pred.argmax(1):
pred 是模型的輸出(通常是未經過 softmax 的 logits,形狀為 [batch_size, num_classes])。
argmax(1) 表示在第二個維度(即類別維度)上找到最大值的索引,返回一個形狀為 [batch_size] 的張量,表示每個樣本的預測類別。pred.argmax(1) == y:
y 是真實標簽(形狀為 [batch_size]),表示每個樣本的真實類別。
這一步會比較預測的類別和真實類別,返回一個布爾張量,形狀為 [batch_size],其中每個元素表示對應樣本的預測是否正確。.type(torch.float):
將布爾張量轉換為浮點數張量(True 轉為 1.0,False 轉為 0.0).sum():
對浮點數張量求和,得到預測正確的樣本總數.item():
將結果從張量轉換為 Python 的標量(整數)
'''
  • 舉例一:

pred?是模型的輸出:torch.tensor([[2.5, 0.3, 0.2], [0.1, 3.2, 0.7]])

y?是真實標簽:torch.tensor([0, 1])

import torchpred = torch.tensor([[2.5, 0.3, 0.2], [0.1, 3.2, 0.7]])
y = torch.tensor([0, 1])correct = (pred.argmax(1) == y).type(torch.float).sum().item()
print(correct)  # 輸出: 2.0 -> 轉換為整數后為 2
  • 舉例二:
import torch# 模型輸出(未經過 softmax 的 logits)
pred = torch.tensor([[2.0, 1.0, 0.1],  # 第一個樣本的預測分數[0.5, 3.0, 0.2],  # 第二個樣本的預測分數[1.2, 0.3, 2.5]]) # 第三個樣本的預測分數# 真實標簽
y = torch.tensor([0, 1, 2])  # 第一個樣本的真實類別是 0,第二個是 1,第三個是 2# 計算預測正確的樣本數
correct = (pred.argmax(1) == y).type(torch.float).sum().item()
print(f"預測正確的樣本數: {correct}") # 預測正確的樣本數: 3'''逐步分析
對每個樣本的預測分數取最大值的索引,得到預測類別:
pred.argmax(1)  # 輸出: tensor([0, 1, 2])比較預測類別和真實標簽,得到布爾張量:
pred.argmax(1) == y  # 輸出: tensor([True, True, True]).type(torch.float): 將布爾張量轉換為浮點數張量:
(pred.argmax(1) == y).type(torch.float)  # 輸出: tensor([1.0, 1.0, 1.0]).sum(): 對浮點數張量求和,得到預測正確的樣本總數:
(pred.argmax(1) == y).type(torch.float).sum()  # 輸出: tensor(3.0).item():將結果從張量轉換為 Python 標量:
(pred.argmax(1) == y).type(torch.float).sum().item()  # 輸出: 3在這個例子中,模型對所有 3 個樣本的預測都正確,因此預測正確的樣本數為 3。
'''
# 如果知道總樣本數,可以進一步計算準確率:# 總樣本數
total = len(y)# 準確率
accuracy = correct / total
print(f"準確率: {accuracy * 100:.2f}%")

訓練過程分多次迭代 (epoch)?進行。在每個 epoch 中,模型會學習 參數進行更好的預測。

然后打印模型在每個 epoch 的準確率和損失,

期望看到 準確率Accuracy增加,損失Avg loss隨著每個 epoch 的減少而減少:

# 跑5輪,每輪皆是先訓練,然后測試
epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

輸出:

Epoch 1
-------------------------------
loss: 2.308106  [   64/60000]
loss: 2.292096  [ 6464/60000]
loss: 2.280747  [12864/60000]
loss: 2.273108  [19264/60000]
loss: 2.256617  [25664/60000]
loss: 2.240094  [32064/60000]
loss: 2.229981  [38464/60000]
loss: 2.204926  [44864/60000]
loss: 2.201917  [51264/60000]
loss: 2.178733  [57664/60000]
Test Error: Accuracy: 46.1%, Avg loss: 2.164820 Epoch 2
-------------------------------
loss: 2.178193  [   64/60000]
loss: 2.160645  [ 6464/60000]
loss: 2.110801  [12864/60000]
loss: 2.129119  [19264/60000]
loss: 2.078400  [25664/60000]
loss: 2.029629  [32064/60000]
loss: 2.044328  [38464/60000]
loss: 1.972220  [44864/60000]
loss: 1.980023  [51264/60000]
loss: 1.920835  [57664/60000]
Test Error: Accuracy: 56.2%, Avg loss: 1.906657 Epoch 3
-------------------------------
loss: 1.938616  [   64/60000]
loss: 1.902610  [ 6464/60000]
loss: 1.797264  [12864/60000]
loss: 1.844325  [19264/60000]
loss: 1.726765  [25664/60000]
loss: 1.688332  [32064/60000]
loss: 1.695883  [38464/60000]
loss: 1.605903  [44864/60000]
loss: 1.628846  [51264/60000]
loss: 1.532240  [57664/60000]
Test Error: Accuracy: 59.8%, Avg loss: 1.541237 Epoch 4
-------------------------------
loss: 1.604458  [   64/60000]
loss: 1.563167  [ 6464/60000]
loss: 1.426733  [12864/60000]
loss: 1.503305  [19264/60000]
loss: 1.376496  [25664/60000]
loss: 1.381424  [32064/60000]
loss: 1.371971  [38464/60000]
loss: 1.312882  [44864/60000]
loss: 1.342990  [51264/60000]
loss: 1.244696  [57664/60000]
Test Error: Accuracy: 62.7%, Avg loss: 1.268371 Epoch 5
-------------------------------
loss: 1.344515  [   64/60000]
loss: 1.318664  [ 6464/60000]
loss: 1.166471  [12864/60000]
loss: 1.275481  [19264/60000]
loss: 1.146058  [25664/60000]
loss: 1.179018  [32064/60000]
loss: 1.171105  [38464/60000]
loss: 1.129168  [44864/60000]
loss: 1.163182  [51264/60000]
loss: 1.077062  [57664/60000]
Test Error: Accuracy: 64.7%, Avg loss: 1.097442 Done!

保存模型

保存模型的常用方法是序列化內部狀態字典(包含模型參數):

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

加載模型

加載模型的過程包括重新創建模型結構和加載 state 字典放入其中。

model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth", weights_only=True))

查看安裝的PyTorch版本

方法一:cmd終端查看

終端中輸入:

>>>python
>>>import torch
>>>torch.__version__  //注意version前后是兩個下劃線

方法二:PyCharm查看

打開Pycharm,在Python控制臺中輸入:

或者在Pycharm的“Python軟件包”中查看:

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/900604.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/900604.shtml
英文地址,請注明出處:http://en.pswp.cn/news/900604.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

windows開啟wsl與輕量級虛擬機管理

基于win 10 打造K8S應用開發環境(wsl & kind) 一、wsl子系統安裝 1.1 確認windows系統版本 cmd/powershell 或者win r 運行winver 操作系統要> 19044 1.2 開啟wsl功能 控制面板 -> 程序 -> 啟用或關閉Windows功能 開啟適用于Linux的…

C++ -異常之除以 0 問題(整數除以 0 編譯時檢測、整數除以 0 運行時檢測、浮點數除以 0 編譯時檢測、浮點數除以 0 運行時檢測)

一、整數除以 0&#xff08;編譯時檢測&#xff09; 1、演示 #include <iostream>using namespace std;int main() {int result 10 / 0;cout << result << endl;return 0; }程序無法運行&#xff0c;輸出結果 error C2124: 被零除或對零求模2、演示解讀 …

【藍橋杯】搜索算法:剪枝技巧+記憶化搜索

1. 可行性剪枝應用 1.1. 題目 題目描述: 給定一個正整數n和一個正整數目標值target,以及一個由不同正整數組成的數組nums。要求從nums中選出若干個數,每個數可以被選多次,使得這些數的和恰好等于target。問有多少種不同的組合方式? 輸入: 第一行:n和target,表示數組…

Uniapp 集成極光推送(JPush)完整指南

文章目錄 前言一、準備工作1. 注冊極光開發者賬號2. 創建應用3. Uniapp項目準備 二、集成極光推送插件方法一&#xff1a;使用UniPush&#xff08;推薦&#xff09;方法二&#xff1a;手動集成極光推送SDK 三、配置原生平臺參數四、核心功能實現1. 獲取RegistrationID2. 設置別…

Linux中進程

一、認識進程 進程(PCB)內核數據結構(task_struct)程序的代碼和數據 每一個進程都有其獨立的task_struct,OS對眾多的task_struct進行管理&#xff0c;如何管理&#xff1f;先描述再組織&#xff0c;所有運?在系統?的進程都以task_struct鏈表的形式存在內核?&#xff0c;而…

國外的AI工具

一 OpenAI &#xff1a; &#x1f4a1; 總覽&#xff1a; 名稱全稱/代號簡介GPT-4o“o” omniOpenAI 最新的旗艦多模態模型&#xff08;文字、圖像、音頻三模態&#xff09;&#xff0c;比 GPT-4 更強、更快、更便宜。GPT-4o-mini精簡版 GPT-4o輕量級版本&#xff0c;推測為性…

企業級Java開發工具MyEclipse v2025.1——支持AI編碼輔助

MyEclipse一次性提供了巨量的Eclipse插件庫&#xff0c;無需學習任何新的開發語言和工具&#xff0c;便可在一體化的IDE下進行Java EE、Web和PhoneGap移動應用的開發&#xff1b;強大的智能代碼補齊功能&#xff0c;讓企業開發化繁為簡。 立即獲取MyEclipse v2025.1正式版 具…

按鍵長按代碼

這些代碼都存放在定時器中斷中。中斷為100ms中斷一次。 數據判斷&#xff0c;看的懂就看吧

在 macOS 上連接 PostgreSQL 數據庫(pgAdmin、DBeaver)

在 macOS 上連接 PostgreSQL 數據庫 pgAdmin 官方提供的圖形化管理工具&#xff0c;支持 macOS。 下載地址&#xff1a;https://www.pgadmin.org/ pgAdmin 4 是對 pgAdmin 的完全重寫&#xff0c;使用 Python、ReactJs 和 Javascript 構建。一個用 Electron 編寫的桌面運行時…

FTP協議和win server2022安裝ftp

FTP協議簡介 FTP&#xff08;File Transfer Protocol&#xff0c;文件傳輸協議&#xff09;是一種用于在網絡上的計算機之間傳輸文件的標準網絡協議。它被廣泛應用于服務器與客戶端之間的文件上傳、下載以及管理操作。FTP支持多種文件類型和結構&#xff0c;并提供了相對簡單的…

人工智能——AdaBoost算法

目錄 摘要 13 AdaBoost算法 13.1 本章工作任務 13.2 本章技能目標 13.3 本章簡介 13.4 編程實戰 13.5 本章總結 13.6 本章作業 本章已完結! 摘要 本章實現的工作是:首先采用Python語言讀取數據并構造訓練集和測試集。然后建立AdaBoost模型,利用訓練集訓練該模型,…

DFS 藍橋杯

最大數字 問題描述 給定一個正整數 NN 。你可以對 NN 的任意一位數字執行任意次以下 2 種操 作&#xff1a; 將該位數字加 1 。如果該位數字已經是 9 , 加 1 之后變成 0 。 將該位數字減 1 。如果該位數字已經是 0 , 減 1 之后變成 9 。 你現在總共可以執行 1 號操作不超過 A…

【開發經驗】調試OpenBMC Redfish EventService功能

EventService功能是Redfish規范中定義的一種事件日志的發送方式。用戶可以設置訂閱者信息(通常是一個web服務器)&#xff0c;當產生事件日志時&#xff0c;OpenBMC可以根據用戶設置的訂閱者信息與對日志的篩選設置&#xff0c;將事件日志發送到訂閱者。 相比于傳統的SNMPTrap日…

中斷嵌套、中斷咬尾、中斷晚到

中斷咬尾&#xff08;Tail-Chaining&#xff09;是一種通過減少上下文切換開銷來實現中斷連續響應的高效機制&#xff0c;其核心在于避免重復的出棧和入棧操作&#xff0c;從而顯著降低中斷延遲。以下是具體原理及實現方式&#xff1a; 中斷咬尾的運作機制 當多個中斷請求連續…

Vue2下載二進制文件

后端&#xff1a; controller: GetMapping(value "/get-import-template")public void problemTemplate(HttpServletRequest request, HttpServletResponse response) throws Exception {iUserService.problemTemplate(request, response);} service: void probl…

Ubuntu小練習

文章目錄 一、遠程連接1、通過putty連接2、查看putty運行狀態3、通過Puuty遠程登錄Ubuntu4、添加新用戶查看是否添加成功 5、用新用戶登錄遠程Ubuntu6、使用VNC遠程登錄樹莓派 二、虛擬機上talk聊天三、Opencv1、簡單安裝版&#xff08;適合新手安裝&#xff09;2、打開VScode特…

996引擎-疑難雜癥:Ctrl + F9 編輯好的UI進入游戲查看卻是歪的

Ctrl F9 編輯好UI后&#xff0c;進入游戲查看卻是歪的。 檢查Ctrl F10 是否有做過編輯。可以找到對應界面執行【清空】

WinForm真入門(5)——控件的基類Control

控件的基類–Control 用于 Windows 窗體應用程序的控件都派生自 Control類并繼承了許多通用成員,這些成員都是平時使用控件的過程最常用到的。無論要學習哪個控件的使用&#xff0c;都離不開這些基本成員&#xff0c;尤其是一些公共屬性。由于 Conlrol 類規范了控件的基本特征…

RAG(檢索增強生成)系統,提示詞(Prompt)表現測試(數據說話)

在RAG(檢索增強生成)系統中,評價提示詞(Prompt)設計是否優秀,必須通過量化測試數據來驗證,而非主觀判斷。以下是系統化的評估方法、測試指標和具體實現方案: 一、提示詞優秀的核心標準 優秀的提示詞應顯著提升以下指標: 維度量化指標測試方法事實一致性Faithfulness …

Appium的學習總結-Inspector參數設置和界面使用(5)

環境搭建好后&#xff0c;怎么使用呢&#xff1f; 環境這里使用的是&#xff1a; Appium的Server端GUI 22版本 Inspector需要單獨下載安裝&#xff0c;GUI里并沒有集成。 &#xff08;使用Appium v1.22.0,查看元素信息需要另外安裝下載Appium Inspector&#xff09; 操作&…