深度學習框架:PyTorch使用教程 !!

文章目錄

一、PyTorch框架簡介

1.1 什么是PyTorch

1.2 PyTorch的優勢

二、從入門到精通的PyTorch使用教程

2.1 入門階段

2.1.1 環境安裝與配置

2.1.2 Tensor基礎操作

2.1.3 自動求導(Autograd)

2.1.4 構建神經網絡(nn模塊)

2.1.5 損失函數與優化器

2.2 進階階段

2.2.1 GPU加速與多GPU使用

2.2.2 數據加載與預處理(torch.utils.data)

2.2.3 自定義模型與層

2.2.4 模型調試與可視化

2.2.5 高級訓練技巧

2.3 實戰應用與精通

2.3.1 遷移學習與預訓練模型

2.3.2 分布式訓練和多機訓練

2.3.3 模型優化與調參

2.3.4 實戰項目示例

2.3.5 框架內部源碼閱讀與擴展

三、總結


一、PyTorch框架簡介

1.1 什么是PyTorch

PyTorch是由Facebook的人工智能研究團隊開發的一款開源深度學習框架。它基于Python語言開發,具有易用性、靈活性和高效性,主要特點包括:

動態計算圖:與TensorFlow的靜態圖相比,PyTorch采用動態圖機制(即運行時定義計算圖),便于調試和開發復雜模型。

自動求導:內置強大的自動求導(Autograd)模塊,可以自動計算梯度,極大簡化了反向傳播算法的實現。

豐富的API:提供了張量(Tensor)運算、神經網絡層(nn模塊)、優化器(optim模塊)等豐富的工具和函數,方便快速搭建各種模型。

GPU加速:支持CUDA,可以方便地將數據和模型轉移到GPU上加速運算。

1.2 PyTorch的優勢

靈活性和易用性:由于采用動態圖機制,用戶可以像寫常規Python程序一樣定義和修改網絡結構,非常適合科研探索與實驗。

社區和生態系統:擁有活躍的開發者社區,提供大量的開源模型、工具包和教程。借助TorchVision、TorchText、TorchAudio等擴展庫,可以更方便地進行圖像、文本和音頻的深度學習研究。

調試方便:動態計算圖使得每一步計算都可以實時查看和修改,極大地方便了調試和模型理解。

二、從入門到精通的PyTorch使用教程

本教程將分為入門、進階和實戰應用三個階段,每個階段都有相應的代碼示例與講解。

2.1 入門階段

2.1.1 環境安裝與配置

打開PyTorch官方,選擇合適的版本進行安裝。

官網地址:Start Locally | PyTorch

  • 安裝方式:可以通過 pip 或 conda 安裝
pip?install torch torchvision

或者

conda?install pytorch torchvision cudatoolkit=11.3?-c pytorch
  • 驗證安裝:安裝完成后,在Python環境中輸入以下代碼檢查是否能正常導入:
import?torchprint(torch.__version__)
2.1.2 Tensor基礎操作
  • 創建Tensor:類似于numpy數組,但可以在GPU上運算。
import?torch# 創建一個未初始化的 3x3 張量
x?= torch.empty(3,?3)
print(x)# 創建一個隨機初始化的張量
x?= torch.rand(3,?3)
print(x)# 創建一個全 0 的張量,并指定數據類型為 long
x?= torch.zeros(3,?3, dtype=torch.long)
print(x)
  • Tensor運算:支持加減乘除等多種運算,并且可以與numpy互轉。
x?= torch.rand(3,?3)
y?= torch.rand(3,?3)
# 基本加法
z?= x + y
# numpy 轉換
np_array?= x.numpy()
x_from_np?= torch.from_numpy(np_array)
2.1.3 自動求導(Autograd)
  • 基本概念:利用Autograd模塊,可以自動記錄每一步運算過程,從而在反向傳播時自動計算梯度。
# 定義一個 tensor,并設置 requires_grad=True
x?= torch.ones(2,?2, requires_grad=True)
print(x)# 定義一個簡單運算
y?= x +?2
z?= y * y *?3
out?= z.mean()# 反向傳播計算梯度
out.backward()
print(x.grad)
  • 注意:計算圖在反向傳播后默認會釋放,如果需要多次反向傳播,需要設置 retain_graph=True。
2.1.4 構建神經網絡(nn模塊)

nn.Module:所有神經網絡模型都需要繼承該類。

import?torch.nn?as?nn
import?torch.nn.functional?as?Fclass?Net(nn.Module):def?__init__(self):super(Net, self).__init__()# 定義一個全連接層:輸入維度 784,輸出維度 10self.fc1 = nn.Linear(784,?10)def?forward(self, x):# 將輸入 x 展平成 (batch_size, 784)x = x.view(-1,?784)x = self.fc1(x)return?F.log_softmax(x, dim=1)net = Net()
print(net)
  • 層級組合:可以將多層組合在一起,形成更復雜的網絡結構。
2.1.5 損失函數與優化器
  • 定義損失函數:例如交叉熵損失函數
criterion?= nn.CrossEntropyLoss()
  • 選擇優化器:例如SGD優化器
import?torch.optim as optim
optimizer?= optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
  • 訓練循環:
for?epoch?in?range(10):for?data, target?in?train_loader:optimizer.zero_grad() ??# 清空梯度output = net(data)loss = criterion(output, target)loss.backward() ? ? ? ??# 反向傳播optimizer.step() ? ? ? ?# 更新參數print(f"Epoch?{epoch}?finished with loss?{loss.item()}")

2.2 進階階段

2.2.1 GPU加速與多GPU使用
  • 將模型和數據遷移到GPU:
device = torch.device("cuda"?if?torch.cuda.is_available()?else?"cpu")
net.to(device)
data, target =?data.to(device), target.to(device)
  • 多GPU并行:利用nn.DataParallel實現模型的多GPU訓練。
if?torch.cuda.device_count() >?1:net?= nn.DataParallel(net)
2.2.2 數據加載與預處理(torch.utils.data)
  • 自定義數據集:繼承 troch.utils.data.Dataset 并重寫 __len__ 與 __getitem__ 方法。
from?torch.utils.data?import?Dataset, DataLoaderclass?MyDataset(Dataset):def?__init__(self, data, labels):self.data = dataself.labels = labelsdef?__len__(self):return?len(self.data)def?__getitem__(self, idx):sample = self.data[idx]label = self.labels[idx]return?sample, labeldataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 常用預處理:使用 torchvision.transforms 對圖像數據進行變換,如裁剪、歸一化、隨機翻轉等。
2.2.3 自定義模型與層
  • 自定義層:除了使用內置的層,也可以根據需求自定義層或模塊。
class?MyLayer(nn.Module):def?__init__(self, in_features, out_features):super(MyLayer,?self).__init__()self.weight = nn.Parameter(torch.randn(in_features, out_features))def?forward(self, x):return?torch.matmul(x,?self.weight)
  • 模塊嵌套:在復雜模型中,可以將子模塊封裝在一起,實現層級化設計。
2.2.4 模型調試與可視化
  • 調試技巧:利用Python調試器(如pdb)或IDE自帶的調試工具,對模型前向傳播、反向傳播過程進行跟蹤。
  • 可視化:使用TensorBoardX或其他可視化工具,監視訓練過程中損失、準確率等指標。
from?torch.utils.tensorboard?import?SummaryWriter
writer =?SummaryWriter(log_dir='./logs')
writer.add_scalar('Loss/train', loss.item(), epoch)
2.2.5 高級訓練技巧
  • 學習率調度:使用torch.optim.lr_scheduler 動態調整學習率,例如StepLR、ReduceLROnPlateau等。
scheduler?= optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
for?epoch in range(10):train(...)scheduler.step()
  • 模型保存與加載:
# 保存模型
torch.save(net.state_dict(),?'model.pth')
# 加載模型
net.load_state_dict(torch.load('model.pth'))
net.eval() ?# 切換到評估模式

2.3 實戰應用與精通

2.3.1 遷移學習與預訓練模型
  • 利用預訓練模型:借助 torchvision.models 中的預訓練模型(如 ResNet、VGG),進行微調(fine-tuning)或特征提取。
import?torchvision.models?as?models
resnet18 = models.resnet18(pretrained=True)
# 凍結部分參數
for?param?in?resnet18.parameters():param.requires_grad =?False
# 修改最后一層
num_features = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_features, num_classes)
2.3.2 分布式訓練和多機訓練

分布式訓練:利用torch.distributed 包,實現跨GPU、跨節點訓練。常見方法包括:

  • DistributedDataParallel(DDP):在單機或多機多卡訓練時比DataParallel更高效。
  • 使用 launch 工具:例如 torch.distributed.launch 腳本啟動分布式訓練任務。

代碼示例:

import?torch.distributed?as?dist
dist.init_process_group(backend='nccl')
net = nn.parallel.DistributedDataParallel(net)
2.3.3 模型優化與調參
  • 超參數搜索:利用網格搜索、隨機搜索或貝葉斯優化等方法,對學習率、正則化系數等超參數進行調優。
  • 正則化技術:使用 Dropout、Batch Normalization 等方法,提高模型的泛化能力。
  • 混合精度訓練:利用 torch.cuda.amp 實現混合精度訓練,既能提升訓練速度,又能降低顯存占用。
scaler = torch.cuda.amp.GradScaler()
for?data, target?in?train_loader:optimizer.zero_grad()with torch.cuda.amp.autocast():output = net(data)loss = criterion(output, target)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()
2.3.4 實戰項目示例
  • 圖像分類:利用CIFAR-10、IamgeNet數據集,搭建卷積神經網絡(CNN)進行圖像分類任務。
  • 自然語言處理:使用RNN、LSTM、Transformer等模型解決文本生成、機器翻譯、情感分析等問題。
  • 生成對抗網絡(GAN):構建生成器與判別器,進行圖像生成任務,體驗對抗訓練的全過程。
2.3.5 框架內部源碼閱讀與擴展
  • 源碼學習:深入閱讀PyTorch的核心模塊(如Autograd、nn.Module)源碼,有助于理解其底層實現原理,從而更好地擴展或定制功能。
  • 擴展開發:基于PyTorch自定義C++擴展或Python API,結合高性能計算需求,打造個性化的深度學習工具。

三、總結

  • 入門階段主要掌握 PyTorch 的基本概念、張量操作、自動求導、基本網絡構建及訓練流程;

  • 進階階段深入理解 GPU 加速、數據加載、調試、可視化、學習率調度等技巧,學會自定義模塊和高效訓練;

  • 實戰應用則通過預訓練模型、分布式訓練、混合精度、超參數優化等高級技術,最終達到精通應用 PyTorch 解決實際問題的水平。

參考資料:矩陣空間,作者-碼匠樂樂

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/81226.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/81226.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/81226.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

系統架構設計師:設計模式——創建型設計模式

一、創建型設計模式 創建型模式抽象了實例化過程,它們幫助一個系統獨立于如何創建、組合和表示它的那些對象。一個類創建型模式使用繼承改變被實例化的類,而一個對象創建型模式將實例化委托給另一個對象。 隨著系統演化得越來越依賴于對象復合而不是類…

Dinero.js - 免費開源的 JavaScript 貨幣處理工具庫,完美解決 JS 浮點數精度丟失問題

今天介紹一個在前后端處理貨幣的工具庫,logo 很可愛,是一只藍色的招財小貓。 本文封面圖底圖來自免費 AI 圖庫 StockCake。 Dinero.js 是一個用于貨幣計算的 JavaScript 工具庫,解決開發者在金融、電商、會計等場景中處理貨幣時的精度丟失、…

HNUST湖南科技大學-嵌入式考試選擇題題庫(109道糾正詳解版)

HNUST嵌入式選擇題題庫 1.下面哪點不是嵌入式操作系統的特點。(B) A.內核精簡 B.功能強大 C.專用性強 D.高實時性 解析: 嵌入式操作系統特點是內核精簡、專用性強、高實時性,而"功能強大"通常指的是通用操作系統&#x…

【工具】Windows批量文件復制教程:用BAT腳本自動化文件管理

一、引言 在日常開發與部署過程中,文件的自動化復制是一個非常常見的需求。無論是在構建過程、自動部署,還是備份任務中,開發者經常需要將某個目錄中的 DLL、配置文件、資源文件批量復制到目標位置。相比使用圖形界面的復制粘貼操作&#xf…

xray-poc編寫示例

禁止未授權掃描和測試行為!!! 1. SQL 時間盲注檢測 (Time-Based Blind SQLi) name: generic/time-based-sqli rules:- method: GETpath: "/product?id1 AND (SELECT 1 FROM (SELECT SLEEP(5))a)--"expression: |response.status…

【Day 14】HarmonyOS分布式數據庫實戰

一、分布式數據庫基礎 1. 核心概念速記表 術語解釋示例場景分布式數據庫數據自動同步到同賬號設備手機添加商品→平板立即顯示KV數據模型鍵值對存儲(類似JSON){"cart_item1": {"name":"牛奶","price":10}}數據…

【數據結構】- 棧

前言: 經過了幾個月的漫長歲月,回頭時年邁的小編發現,數據結構的內容還沒有寫博客,于是小編趕緊停下手頭的活動,補上博客以洗清身上的罪孽 目錄 前言: 棧的應用 括號匹配 逆波蘭表達式 數制轉換 棧的實…

TDA4VM SDK J721E (RTOS/Linux) bootloaders梳理筆記

文章目錄 1. 前言2. RTOS BootLoader2.1 引導模式2.2 啟動序列2.2.1 流程框圖2.2.2 Memory map2.3 鏡像格式詳解3. Linux BootLoader鏡像格式詳解啟動流程參考1. 前言 TDA4VM的BootLoader包含兩部分:RTOS的和Linux的。 2. RTOS BootLoader 這是在SoC上的所有內核運行FreeRTO…

Spring Boot + MyBatis-Plus 的現代開發模式

之前的Maven項目和本次需要的環境配置并不一樣 之前使用的是: 傳統的 MyBatis 框架(非 Spring Boot 環境) 手動管理 SqlSession 使用了 .xml 的 Mapper 映射文件 沒有 Spring 容器管理(沒有 Service / RestController 等&…

【Quest開發】極簡版!透視環境下摳出身體并能遮擋身體上的服裝

前兩天發了一個很復雜的版本,又鼓搗了一下發現完全沒有必要。我之前的理解有點偏(不是錯誤的但用法錯了),但是有一些小伙伴收藏了,害怕里面的某些東西對誰有用,所以寫了一篇新的,前兩步配置環境…

vue 常見ui庫對比(element、ant、antV等)

Element UI 1. 簡介 Element UI 是一個基于 Vue 2 和 Vue 3 的企業級 UI 組件庫,提供了豐富的組件和主題定制功能。官方網站:Element UI 2. 主要特點 豐富的組件:包括表單、表格、布局、導航、彈窗等多種組件。主題定制:支持主…

MATLAB畫一把傘

% 傘的參數num_ribs 5; % 傘骨數量修改為5R 1; % 傘的半徑height 0.5; % 傘的高度handle_length 2; % 傘柄長度semicircle_radius 0.26; % 傘柄末端半圓的半徑% 生成傘葉網格theta linspace(0, 2*pi, 100);phi linspace(0, pi/2, 50);[Theta, Phi] meshgrid(theta, phi…

如何在 Go 中實現各種類型的鏈表?

鏈表是動態內存分配中最常見的數據結構之一。它由一組有限的元素組成,每個元素(節點)至少占用兩塊內存:一塊用于存放數據,另一塊用于存放指向下一個節點的指針。本文教程將說明在 Go 語言中如何借助指針和結構體類型來…

新一代機載相控陣雷達的發展

相控陣雷達以其優越的性能在軍事領域中有著廣闊的應用前景,但由于復雜的技術、昂貴的造價使其應用范圍還存在一定的局限性。然而,國內外對相控陣技術的研究非常重視,并取得了豐碩的成果。 軍用相控陣雷達主要分為陸基、海基和空基幾種類型。 …

多數元素題解(LC:169)

169. 多數元素 核心思想(Boyer-Moore 投票算法): 解題思路:可以使用 Boyer-Moore 投票算法、該算法的核心思想是: 維護一個候選元素和計數器、初始時計數器為 0。 遍歷數組: 當計數器為 0 時、設置當前元…

數據庫 AI 助手測評:Chat2DB、SQLFlow 等工具如何提升開發效率?

一、引言:數據庫開發的 “效率革命” 正在發生 在某互聯網金融公司的凌晨故障現場,資深 DBA 正滿頭大汗地排查一條執行超時的 SQL—— 該語句涉及 7 張核心業務表的復雜關聯,因索引缺失導致全表掃描,最終引發交易系統阻塞。這類場景在傳統數據庫開發中屢見不鮮:據 Gartne…

【中間件】bthread效率為什么高?

bthread效率為什么更高? 1 基本概念 bthread是brpc中的用戶態線程(也可稱為M:N線程庫),目的是:提高程序的并發度,同時降低編碼難度,在多核cpu上提供更好的scalability和cache locality。其采用…

DeepSeek V2:引入MLA機制與指令對齊

長上下文革命:Multi-Head Latent Attention(MLA)機制 傳統 Transformer 的多頭注意力需要緩存所有輸入token的 Key 和 Value,這對長文本推理時的內存開銷極為龐大。DeepSeek V2 針對這一難題提出了“Multi-Head Latent Attention”(MLA)機制。MLA 的核心思想是對多頭注意…

Druid監控sql導致的內存溢出--內存分析工具MemoryAnalyzer(mat)

問題 druid監控sql在網頁端顯示&#xff0c;我的服務插入sql比較大&#xff0c;druid把執行過的sql保存在DruidDataSource類的成員變量JdbcDataSourceStat dataSourceStat&#xff1b; JdbcDataSourceStat類中的LinkedHashMap<String, JdbcSqlStat> sqlStatMap中&#…

《Python實戰進階》No45:性能分析工具 cProfile 與 line_profiler

Python實戰進階 No45&#xff1a;性能分析工具 cProfile 與 line_profiler 摘要 在AI模型開發中&#xff0c;代碼性能直接影響訓練效率和資源消耗。本節通過cProfile和line_profiler工具&#xff0c;實戰演示如何定位Python代碼中的性能瓶頸&#xff0c;并結合NumPy向量化操作…