深入理解 PyTorch:從基礎到高級應用

在深度學習的浪潮中,PyTorch 憑借其簡潔易用、動態計算圖等特性,迅速成為眾多開發者和研究人員的首選框架。本文將深入探討 PyTorch 的核心概念、基礎操作以及高級應用,帶你全面了解這一強大的深度學習工具。?

一、PyTorch 簡介?

PyTorch 是一個基于 Python 的科學計算包,主要用于深度學習領域。它由 Facebook 的 AI 研究小組(FAIR)開發,旨在為深度學習提供一個靈活、高效且易于使用的平臺。PyTorch 具有以下幾個顯著特點:?

  1. 動態計算圖:與 TensorFlow 等框架使用的靜態計算圖不同,PyTorch 采用動態計算圖。這意味著在運行時可以根據條件和循環動態構建計算圖,使得調試更加方便,代碼編寫也更加靈活。例如,在訓練過程中,我們可以根據當前的訓練狀態動態調整網絡結構或計算邏輯。?
  1. Pythonic 風格:PyTorch 的設計理念遵循 Python 的簡潔和直觀風格,易于學習和使用。對于熟悉 Python 的開發者來說,能夠快速上手 PyTorch。其 API 設計也非常符合 Python 的編程習慣,代碼可讀性強。?
  1. 強大的 GPU 支持:PyTorch 能夠充分利用 GPU 的并行計算能力,大幅提升深度學習模型的訓練速度。通過簡單的操作,就可以將數據和模型移動到 GPU 上進行計算。?
  1. 豐富的生態系統:PyTorch 擁有龐大的社區和豐富的工具庫,如 TorchVision(用于計算機視覺任務)、TorchText(用于自然語言處理任務)等,方便開發者快速實現各種深度學習應用。?

二、PyTorch 基礎操作?

1. 張量(Tensor)?

張量是 PyTorch 中最基本的數據結構,類似于 NumPy 中的數組。它可以是一個標量(0 維張量)、向量(1 維張量)、矩陣(2 維張量)或更高維的數組。?

創建張量的方式有多種:?

  • 直接創建:?

TypeScript

取消自動換行復制

import torch?

# 創建一個5x3的未初始化張量?

x = torch.empty(5, 3)?

print(x)?

# 創建一個5x3的隨機初始化張量?

y = torch.rand(5, 3)?

print(y)?

# 創建一個5x3的全0張量,數據類型為long?

z = torch.zeros(5, 3, dtype=torch.long)?

print(z)?

  • 從數據創建:?

TypeScript

取消自動換行復制

# 從Python列表創建張量?

data = [[1, 2], [3, 4]]?

a = torch.tensor(data)?

print(a)?

?

  • 基于現有張量創建:?

?

TypeScript

取消自動換行復制

# 使用現有張量的屬性創建新張量?

b = a.new_ones(5, 3, dtype=torch.double)?

print(b)?

?

# 創建與a相同大小和數據類型的隨機張量?

c = torch.randn_like(a, dtype=torch.float)?

print(c)?

?

張量支持各種數學運算,如加法、減法、乘法等,運算方式與 NumPy 類似:?

?

TypeScript

取消自動換行復制

# 加法運算?

result = y + z?

print(result)?

?

# 另一種加法運算方式?

result = torch.add(y, z)?

print(result)?

?

# 原地加法運算(直接修改z)?

z.add_(y)?

print(z)?

?

2. 自動求導(Autograd)?

Autograd 是 PyTorch 中用于自動計算梯度的模塊。在深度學習中,我們需要通過反向傳播計算梯度來更新模型參數,Autograd 可以自動完成這一過程。?

要使用 Autograd,只需將張量的requires_grad屬性設置為True,表示需要計算該張量的梯度。例如:?

?

TypeScript

取消自動換行復制

x = torch.ones(2, 2, requires_grad=True)?

print(x)?

?

y = x + 2?

print(y)?

?

z = y * y * 3?

out = z.mean()?

print(out)?

?

在上述代碼中,x、y、z和out的requires_grad屬性都為True。通過調用out.backward(),可以自動計算out關于x的梯度:?

?

TypeScript

取消自動換行復制

out.backward()?

print(x.grad)?

?

3. 設備(Device)?

PyTorch 支持在 CPU 和 GPU 上進行計算。通過to()方法,可以將張量和模型移動到指定的設備上。首先需要判斷是否有可用的 GPU:?

?

TypeScript

取消自動換行復制

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")?

print(device)?

?

然后將張量移動到設備上:?

?

TypeScript

取消自動換行復制

x = torch.tensor([1, 2, 3])?

x = x.to(device)?

print(x)?

?

對于模型,也可以使用相同的方法將其移動到設備上:?

?

TypeScript

取消自動換行復制

import torch.nn as nn?

?

model = nn.Linear(10, 2)?

model = model.to(device)?

?

三、PyTorch 神經網絡?

1. 定義神經網絡?

在 PyTorch 中,定義神經網絡通常繼承nn.Module類,并實現__init__和forward方法。__init__方法用于定義網絡層,forward方法用于定義數據的前向傳播過程。?

以下是一個簡單的全連接神經網絡示例:?

?

TypeScript

取消自動換行復制

import torch.nn as nn?

import torch.nn.functional as F?

?

class Net(nn.Module):?

def __init__(self):?

super(Net, self).__init__()?

# 輸入圖像大小為32x32,1個通道,輸出6個特征圖?

self.conv1 = nn.Conv2d(1, 6, 3)?

# 輸入6個特征圖,輸出16個特征圖?

self.conv2 = nn.Conv2d(6, 16, 3)?

# 全連接層,輸入16 * 6 * 6個神經元,輸出120個神經元?

self.fc1 = nn.Linear(16 * 6 * 6, 120)?

self.fc2 = nn.Linear(120, 84)?

self.fc3 = nn.Linear(84, 10)?

?

def forward(self, x):?

# 卷積層 + ReLU激活函數 + 最大池化?

x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))?

x = F.max_pool2d(F.relu(self.conv2(x)), 2)?

# 將張量展平為一維向量?

x = x.view(-1, self.num_flat_features(x))?

x = F.relu(self.fc1(x))?

x = F.relu(self.fc2(x))?

x = self.fc3(x)?

return x?

?

def num_flat_features(self, x):?

size = x.size()[1:] # 除批量維度外的所有維度?

num_features = 1?

for s in size:?

num_features *= s?

return num_features?

?

?

net = Net()?

print(net)?

?

2. 損失函數和優化器?

訓練神經網絡需要定義損失函數和優化器。常見的損失函數有均方誤差損失函數(nn.MSELoss)、交叉熵損失函數(nn.CrossEntropyLoss)等。優化器有隨機梯度下降(torch.optim.SGD)、Adam 優化器(torch.optim.Adam)等。?

?

TypeScript

取消自動換行復制

import torch.optim as optim?

?

# 定義損失函數?

criterion = nn.CrossEntropyLoss()?

?

# 定義優化器?

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)?

?

3. 訓練神經網絡?

訓練神經網絡的一般步驟如下:?

  1. 前向傳播,計算預測值。?
  1. 計算損失。?
  1. 反向傳播,計算梯度。?
  1. 使用優化器更新模型參數。?

?

TypeScript

取消自動換行復制

for epoch in range(2):?

running_loss = 0.0?

for i, data in enumerate(trainloader, 0):?

# 獲取輸入數據和標簽?

inputs, labels = data[0].to(device), data[1].to(device)?

?

# 梯度清零?

optimizer.zero_grad()?

?

# 前向傳播 + 反向傳播 + 優化?

outputs = net(inputs)?

loss = criterion(outputs, labels)?

loss.backward()?

optimizer.step()?

?

# 打印統計信息?

running_loss += loss.item()?

if i % 2000 == 1999:?

print('[%d, %5d] loss: %.3f' %?

(epoch + 1, i + 1, running_loss / 2000))?

running_loss = 0.0?

?

print('Finished Training')?

?

四、PyTorch 高級應用?

1. 預訓練模型?

PyTorch 提供了許多預訓練模型,如 ResNet、VGG、BERT 等。我們可以直接加載這些預訓練模型,并在其基礎上進行微調,以適應特定的任務。?

以加載 ResNet18 預訓練模型為例:?

?

TypeScript

取消自動換行復制

import torchvision.models as models?

?

# 加載預訓練的ResNet18模型?

model = models.resnet18(pretrained=True)?

?

# 凍結所有參數,不進行訓練?

for param in model.parameters():?

param.requires_grad = False?

?

# 修改最后一層全連接層,以適應新的分類任務?

num_ftrs = model.fc.in_features?

model.fc = nn.Linear(num_ftrs, 2)?

?

2. 自定義數據集和數據加載器?

在實際應用中,我們通常需要處理自定義的數據集。通過繼承torch.utils.data.Dataset類,可以創建自定義數據集,并使用torch.utils.data.DataLoader進行數據加載和批量處理。?

?

TypeScript

取消自動換行復制

import torch.utils.data as data?

?

class CustomDataset(data.Dataset):?

def __init__(self, data_list, label_list, transform=None):?

self.data_list = data_list?

self.label_list = label_list?

self.transform = transform?

?

def __len__(self):?

return len(self.data_list)?

?

def __getitem__(self, index):?

data = self.data_list[index]?

label = self.label_list[index]?

if self.transform is not None:?

data = self.transform(data)?

return data, label?

?

?

# 使用示例?

custom_dataset = CustomDataset(data_list, label_list)?

dataloader = data.DataLoader(custom_dataset, batch_size=4, shuffle=True)?

?

3. 分布式訓練?

對于大規模的深度學習任務,分布式訓練可以顯著提高訓練效率。PyTorch 提供了分布式訓練的支持,通過torch.distributed模塊可以實現多機多卡的分布式訓練。?

以下是一個簡單的分布式訓練示例(假設在單機多卡環境下):?

?

TypeScript

取消自動換行復制

import torch.distributed as dist?

import torch.multiprocessing as mp?

?

def train(rank, world_size):?

# 初始化分布式環境?

dist.init_process_group("nccl", rank=rank, world_size=world_size)?

?

# 每個進程創建一個模型和優化器?

model = nn.Linear(10, 2).to(rank)?

optimizer = optim.SGD(model.parameters(), lr=0.001)?

?

# 數據并行包裝模型?

model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])?

?

# 訓練過程?

for epoch in range(2):?

running_loss = 0.0?

for i, data in enumerate(trainloader, 0):?

inputs, labels = data[0].to(rank), data[1].to(rank)?

optimizer.zero_grad()?

outputs = model(inputs)?

loss = criterion(outputs, labels)?

loss.backward()?

optimizer.step()?

running_loss += loss.item()?

print('Rank {} loss: {:.3f}'.format(rank, running_loss))?

?

# 銷毀分布式環境?

dist.destroy_process_group()?

?

?

if __name__ == '__main__':?

world_size = torch.cuda.device_count()?

mp.spawn(train, args=(world_size,), nprocs=world_size)?

?

五、總結?

本文全面介紹了 PyTorch 的核心概念、基礎操作、神經網絡構建以及高級應用。從張量的創建和運算,到自動求導、神經網絡訓練,再到預訓練模型、自定義數據集和分布式訓練,涵蓋了 PyTorch 在深度學習開發中的主要方面。希望通過本文的學習,你能夠對 PyTorch 有更深入的理解,并在實際項目中熟練運用這一強大的深度學習框架。隨著深度學習技術的不斷發展,PyTorch 也在持續更新和完善,未來還會有更多強大的功能和應用場景等待我們去探索和實踐。?

以上博客詳細梳理了 Pytorch 從基礎到進階的知識。如果你對某個部分還想進一步了解,或者有特定的應用場景想探討,歡迎隨時告訴我。?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/84819.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/84819.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/84819.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Java 中的 synchronized 與 Lock:深度對比、使用場景及高級用法

💡 前言 在多線程并發編程中,線程安全問題始終是開發者需要重點關注的核心內容之一。Java 提供了多種機制來實現同步控制,其中最常用的兩種方式是: 使用 synchronized 關鍵字使用 java.util.concurrent.locks.Lock 接口&#xf…

Notepad++如何列選

在 Notepad 中,你可以通過 列模式(Column Mode) 進行垂直選擇文本(列選),以下是具體操作方法: 方法 1:鍵盤 鼠標列選 按住 Alt 鍵(或 Alt Shift)。 按住鼠…

華為OD機考-水仙花數Ⅰ-邏輯分析(JAVA 2025B卷)

import java.util.*; public static Integer get(int count,int c){if(count<3||count>7){return -1;}//存儲每位數的最高位……最低位int[] arr new int[count];List<Integer> res new ArrayList<>();for(int i(int) Math.pow(10,count-1);i<(int) Math…

基于 STL+VMD 二次分解的 Informer-LSTM 并行預測模型詳解與案例

一、背景與動機 在時間序列預測中,如電力負荷、風速、交通流量等復雜數據常表現為: 非線性:趨勢+季節+突變+噪聲 多尺度:高頻擾動與低頻變化共存 長時依賴:遠期信息也影響當前預測 傳統模型(如 ARIMA、LSTM)往往無法兼顧全局趨勢建模與局部擾動感知,因此我們提出一種 …

【Linux Learning】SSH連線出現警告:WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!

問題&#xff1a;WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY! Someone could be eavesdropping on you right now (man-in-the-middle attack)! It is al…

輕量級密碼算法PRESENT的C語言實現(無第三方庫)

一、PRESENT算法介紹 PRESENT是一種超輕量級分組密碼算法&#xff0c;由Bogdanov等人在2007年提出&#xff0c;專門為資源受限環境如RFID標簽和傳感器網絡設計。該算法在硬件實現上僅需1570個門等效電路(GE)&#xff0c;在保持較高安全性的同時實現了極小的硬件占用空間。PRES…

if的簡化書寫,提高執行效率

很多時候可能有下面判斷 if(a0) {b1;} else if(a1) {b0;} 就是ba的反向值&#xff1a; a0;b1&#xff1b; a1;b0; 這時&#xff0c;可以簡化如下&#xff1a; ba^1 使用異或&#xff0c;程序更簡潔&#xff0c;執行效率也更高 其他的也可以類似使用按位異或優化代碼

Vim 調用外部命令學習筆記

Vim 外部命令集成完全指南 文章目錄 Vim 外部命令集成完全指南核心概念理解命令語法解析語法對比 常用外部命令詳解文本排序與去重文本篩選與搜索高級 grep 搜索技巧文本替換與編輯字符處理高級文本處理編程語言處理其他實用命令 范圍操作示例指定行范圍處理復合命令示例 實用技…

bash挖礦木馬事件全景復盤與企業級防御實戰20250612

&#x1f427; CentOS “-bash 挖礦木馬” 事件全景復盤與企業級防御實戰 ?? 作者&#xff1a;Narutolxy | &#x1f4c5; 日期&#xff1a;2025-06-12 | &#x1f3f7;? 標簽&#xff1a;Linux 安全、應急響應、運維加固、實戰復盤 &#x1f4d8; 內容簡介 本文是一場真實…

「Linux中Shell命令」Shell命令基礎

知識點詳細解析 Shell簡介 Shell是Linux操作系統系統中用戶與操作系統內核交互的接口。它既是命令解釋器,負責接收用戶輸入的命令并將其轉換為內核能夠理解的指令,也是一種腳本編程語言。作為Linux操作系統的重要組成部分,Shell扮演著用戶與系統內核之間的"中間人"…

202557讀書筆記|《夢里花落知多少(輕經典)》——有你在的地方才最美

《夢里花落知多少&#xff08;輕經典&#xff09;》作者三毛&#xff0c;物極必反&#xff0c;陰晴圓缺&#xff0c;小滿即萬全么&#xff1f;因為幸福過于滿溢。所以幸福被收走了。 沒有看過太多三毛的作品&#xff0c;給我的感覺她是很敏感&#xff0c;多愁善感及沒有安全感…

對象映射 C# 中 Mapster 和 AutoMapper 的比較

Mapster和AutoMapper是C#領域兩大主流對象映射庫&#xff0c;各具特色。Mapster以高性能著稱&#xff0c;使用表達式樹實現零反射映射&#xff0c;首次編譯后執行效率極高&#xff0c;適合對性能敏感的場景&#xff1b;AutoMapper則提供更豐富的功能集&#xff0c;如條件映射和…

QEMU源碼全解析 —— 塊設備虛擬化(26)

接前一篇文章:QEMU源碼全解析 —— 塊設備虛擬化(25) 本文內容參考: 《趣談Linux操作系統》 —— 劉超,極客時間 《QEMU/KVM源碼解析與應用》 —— 李強,機械工業出版社 Virt

微軟PowerBI考試 PL300-選擇 Power BI 模型框架【附練習數據】

微軟PowerBI考試 PL300-選擇 Power BI 模型框架 20 多年來&#xff0c;Microsoft 持續對企業商業智能 (BI) 進行大量投資。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于無數企業使用的成熟的 BI 數據建模技術。 同樣的技術也是 Power BI 數據…

RED DA認證-EN18031網絡安全常見問題以及解答

Q&#xff1a;RED DA是否對所有無線模塊和設備強制要求&#xff1f; A&#xff1a;是的&#xff0c;RED DA適用于歐盟境內銷售的所有無線設備&#xff0c;包括WWAN、藍牙或Wi-Fi模塊。唯一例外是GNSS模塊&#xff08;僅支持接收功能&#xff0c;無需認證&#xff09;。 Q&…

騰訊開源 ovCompose 跨平臺框架:實現一次跨三端(Android/iOS/鴻蒙)

在移動應用開發領域&#xff0c;跨平臺技術一直是開發者們追求的目標&#xff0c;它能夠幫助企業降低開發成本、提高開發效率&#xff0c;同時保證應用在不同平臺上的一致性體驗。2025 年 6 月 3 日&#xff0c;騰訊視頻團隊迎來了一個重要的里程碑 —— 正式發布 ovCompose 跨…

對3D對象進行形變分析

1&#xff0c;目的 分析3D實例對象相對標準參照物的形變。 一般用于質地較軟的材質&#xff08;例如橡膠&#xff0c;布料&#xff09;查找&#xff0c;檢查等。 標準參考模型 需匹配的實例&#xff1a; 形變后的模型&#xff1a;* 形變后的模型&#xff1a; 實例形變后的…

寶塔面板WordPress中使用Contact Form 7插件收不到郵件的解決方法

如果是寶塔面板的環境下&#xff0c;在WordPress中使用Contact Form 7插件提交表單時顯示成功&#xff0c;但郵箱未收到郵件&#xff0c;可能是由于服務器郵件功能配置問題。以下是幾種常見解決方法&#xff1a; 1. 檢查郵件發送方式 默認情況下&#xff0c;Contact Form 7 使…

Android中的DX、D8、R8

Kotlin 版本所需的 AGP、D8 和 R8 版本 :https://developer.android.google.cn/build/kotlin-support?hlzh_cn R8&#xff1a;https://developer.android.google.cn/tools/retrace?hlzh_cn D8&#xff1a;https://developer.android.google.cn/tools/d8?hlzh_cn 如上圖&…

通義靈碼 AI IDE 上線!智能體+MCP 從手動調用工具過渡到“AI 主動調度資源”

告訴大家一個好消息&#xff0c;通義靈碼發布了 AI 編程 IDE &#xff1a;Lingma IDE &#xff0c;你沒看錯&#xff0c;通義靈碼也推出了自己的 AI IDE 客戶端&#xff0c;不是 AI 編程插件&#xff0c;是 IDE 。 Lingma IDE 是基于 VS Code 開源版本構建的智能代碼編輯器&am…