《Python實戰進階》第33集:PyTorch 入門-動態計算圖的優勢

第33集:PyTorch 入門-動態計算圖的優勢


摘要

PyTorch 是一個靈活且強大的深度學習框架,其核心特性是動態計算圖機制。本集將帶您探索 PyTorch 的張量操作、自動求導系統以及動態計算圖的特點與優勢,并通過實戰案例演示如何使用 PyTorch 實現線性回歸和構建簡單的圖像分類模型。我們將重點突出 PyTorch 在研究與開發中的靈活性及其在 AI 大模型訓練中的應用。
在這里插入圖片描述


核心概念和知識點

1. 張量操作與自動求導

  • 張量(Tensor):類似于 NumPy 數組,但支持 GPU 加速。
  • 自動求導(Autograd):PyTorch 提供了自動微分功能,能夠高效計算梯度,用于優化模型參數。

2. 動態計算圖的特點與優勢

  • 動態計算圖:PyTorch 的計算圖是在運行時動態構建的,支持即時調試和修改。
  • 靈活性:適合實驗性研究,便于實現復雜的模型架構。
  • 直觀性:代碼執行過程清晰可見,易于理解。

3. 自定義模型與訓練循環

  • 模型定義:通過繼承 torch.nn.Module 自定義模型結構。
  • 訓練循環:手動實現前向傳播、損失計算和反向傳播,提供更細粒度的控制。

4. AI 大模型相關性分析

PyTorch 是目前主流的 AI 大模型框架之一,廣泛應用于 GPT、BERT 等模型的訓練:

  • 分布式訓練支持:通過 torch.distributed 模塊實現多 GPU 和多節點訓練。
  • 生態系統豐富:結合 Hugging Face Transformers 等庫,可快速搭建和訓練大模型。

實戰案例

案例 1:使用 PyTorch 實現線性回歸

背景

線性回歸是最基礎的機器學習任務之一,我們使用 PyTorch 實現一個簡單的線性回歸模型。

代碼實現
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 數據生成
torch.manual_seed(42)
x = torch.linspace(-1, 1, 100).reshape(-1, 1)  # 輸入特征
y = 3 * x + 2 + 0.2 * torch.randn(x.size())   # 帶噪聲的目標值# 定義模型
class LinearRegressionModel(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(1, 1)  # 單輸入單輸出的線性層def forward(self, x):return self.linear(x)model = LinearRegressionModel()# 定義損失函數和優化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)# 訓練模型
epochs = 100
for epoch in range(epochs):# 前向傳播y_pred = model(x)loss = criterion(y_pred, y)# 反向傳播與優化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")# 可視化結果
predicted = model(x).detach().numpy()
plt.scatter(x.numpy(), y.numpy(), label="Original Data", alpha=0.6)
plt.plot(x.numpy(), predicted, 'r', label="Fitted Line")
plt.legend()
plt.title("Linear Regression with PyTorch")
plt.show()
輸出結果
Epoch 10/100, Loss: 0.0431
...
Epoch 100/100, Loss: 0.0012
可視化

案例 2:構建一個簡單的圖像分類模型

背景

我們使用 CIFAR-10 數據集,構建一個簡單的卷積神經網絡(CNN)進行圖像分類。

代碼實現
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms# 數據加載與預處理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)# 定義 CNN 模型
class SimpleCNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16 * 16 * 16, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = x.view(-1, 16 * 16 * 16)x = self.fc1(x)return xmodel = SimpleCNN()# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練模型
for epoch in range(5):  # 僅訓練 5 個 epochrunning_loss = 0.0for i, data in enumerate(trainloader):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:  # 每 100 個 batch 打印一次損失print(f"[{epoch+1}, {i+1}] Loss: {running_loss / 100:.3f}")running_loss = 0.0print("Finished Training")# 測試模型
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Accuracy on Test Set: {100 * correct / total:.2f}%")
輸出結果
[1, 100] Loss: 2.123
...
Accuracy on Test Set: 55.25%

總結

PyTorch 的動態計算圖機制使其成為深度學習研究與開發的理想工具。通過本集的學習,我們掌握了如何使用 PyTorch 實現線性回歸和構建簡單的圖像分類模型,并了解了其在靈活性和實驗性方面的優勢。


擴展思考

1. PyTorch 在 AI 大模型訓練中的應用

PyTorch 是訓練 GPT、BERT 等大模型的核心工具之一。其動態計算圖機制使得研究人員能夠快速迭代模型架構,而分布式訓練支持則確保了大模型的高效訓練。

2. PyTorch Lightning 的簡化功能

PyTorch Lightning 是一個高級接口,旨在簡化 PyTorch 的使用。它隱藏了訓練循環的復雜性,同時保留了底層靈活性,特別適合大規模實驗和生產環境。


專欄鏈接:Python實戰進階
下期預告:No34 - 使用 Pandas 高效處理時間序列數據

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/75532.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/75532.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/75532.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

初識哈希表

一、題意 給定一個整數數組 nums 和一個目標值 target,要求你在數組中找出和為目標值的那兩個整數,并返回它們的數組下標。你可以假設每種輸入只會對應一個答案。但是,數組中同一個元素不能使用兩遍。 示例: 給定 nums [2, 7, …

23種設計模式-創建型模式-單例

文章目錄 簡介問題1. 確保一個類只有一個實例2. 為該實例提供全局訪問點 解決方案示例重構前:重構后: 拓展volatile 在單例模式中的雙重作用 總結 簡介 單例是一種創建型設計模式,它可以確保一個類只有一個實例,同時為該實例提供…

python裁剪nc文件數據

問題描述: 若干個nc文件儲存全球的1850-2014年月尺度的mrro數據(或其他數據),從1850-1到2014-12一共1980個月,要提取出最后35年1980.1~2014.12年也就是420個月的數據。 代碼實現 def aaa(input_file,output_file,bianliang,start_index,en…

深入解析 Spring Framework 5.1.8.RELEASE 的源碼目錄結構

深入解析 Spring Framework 5.1.8.RELEASE 的源碼目錄結構 1. 引言 Spring Framework 是 Java 領域最流行的企業級開發框架之一,廣泛用于 Web 開發、微服務架構、數據訪問等場景。本文將深入解析 Spring Framework 5.1.8.RELEASE 的源碼目錄結構,幫助開…

數據清洗:基于python抽取jsonl文件數據字段

基于python抽取目錄下所有“jsonl”格式文件。遍歷文件內某個字段進行抽取并合并。 import os import json import time from tqdm import tqdm # 需要先安裝:pip install tqdmdef process_files():# 設置目錄路徑dir_path r"D:\daku\關鍵詞識別\1623-00000…

Windows 下使用 Docker 部署 Go 應用與 Nginx 詳細教程

一、環境準備 1. 安裝必要軟件 Docker Desktop for Windows 下載地址:Docker Desktop: The #1 Containerization Tool for Developers | Docker 安裝時勾選"使用 WSL 2 引擎"(推薦) WSL 2(Windows Subsystem for Li…

C# .net ai Agent AI視覺應用 寫代碼 改作業 識別屏幕 標注等

C# net deepseek RAG AI開發 全流程 介紹_c# 向量處理 deepseek-CSDN博客 視覺多模態大模型 通義千問2.5-VL-72B AI大模型能看懂圖 看懂了后能干啥呢 如看懂圖 讓Agent 寫代碼 ,改作業,識別屏幕 標注等等。。。 據說是目前最好的免費圖片識別框架 通…

Docker多階段構建:告別臃腫鏡像的終極方案

Docker多階段構建:告別臃腫鏡像的終極方案 你是否遇到過這樣的問題:一個簡單的應用,Docker鏡像卻高達1GB?編譯工具、臨時文件、開發依賴全被打包進去,導致鏡像臃腫且不安全。 多階段構建(Multi-stage Build) 就是為解決這一問題而生——它像搬家時“只帶必需品”,讓生…

大模型應用開發之大模型工作流程

一:大模型的問答工作流程 1.1: 分詞和向量化 如上圖所示,我們如果讓大模型去回答問題,首先我們會輸入一些文字給到大模型,大模型本質上是個數學模型,它是理解不了人類的整句話的,所以它會把我們的對應的句…

SpringMVC 請求處理

SpringMVC 請求處理深度解析:從原理到企業級應用實踐 一、架構演進與核心組件協同 1.1 從傳統Servlet到前端控制器模式 SpringMVC采用前端控制器架構模式,通過DispatcherServlet統一處理請求,相比傳統Servlet的分散處理方式,實…

12屆藍橋杯—貨物擺放

貨物擺放 題目描述 小藍有一個超大的倉庫,可以擺放很多貨物。 現在,小藍有 nn 箱貨物要擺放在倉庫,每箱貨物都是規則的正方體。小藍規定了長、寬、高三個互相垂直的方向,每箱貨物的邊都必須嚴格平行于長、寬、高。 小藍希望所…

Reactor/Epoll為什么可以高性能?

在 Reactor 模式中使用 epoll_wait 實現低 CPU 占用率的核心原理是 ?事件驅動的阻塞等待機制,而非忙等待。以下通過分步驟解析其工作原理和性能優勢: void network_thread() {int epoll_fd epoll_create1(0);epoll_event events[MAX_EVENTS];// 添加U…

批量優化與壓縮 PPT,減少 PPT 文件的大小

我們經常能夠看到有些 PPT 文檔明明沒有多少內容,但是卻占用了很大的空間,存儲和傳輸非常的不方便,這時候通常是因為我們插入了一些圖片/字體等資源文件,這些都可能會導致我們的 PPT 文檔變得非常的龐大,今天就給大家介…

Java基礎 3.22

1.break練習 //1-100之內的數求和&#xff0c;求當和第一次大于20的當前數i public class Break01 {public static void main(String[] args) {int n 0;int count 0;for (int i 1; i < 100; i) {count i;System.out.println("當前和為" count);if (count &g…

高性能MySQL筆記

高性能MySQL筆記 《高性能MySQL》第1章 MySQL架構**第一章核心知識點總結****多選題**多選題答案**答案與詳解總結** 《高性能MySQL》第2章 可靠性程世界中的監控核心知識點多選題答案及解析重點鞏固方向 《高性能MySQL》第3章 Performance Schema**第三章核心知識點總結****多…

導游職業資格考試:從迷茫到清晰的備考指南

當你決定報考導游職業資格考試時&#xff0c;可能會感到有些迷茫&#xff0c;不知道從何處入手。別擔心&#xff0c;這份備考指南將帶你從迷茫走向清晰。? 第一步&#xff0c;全面了解考試。導游職業資格考試分為筆試和面試。筆試的四個科目各有特點&#xff0c;《政策與法律…

【BFS】《BFS 攻克 FloodFill:填平圖形世界的技術密碼》

文章目錄 前言例題一、 圖像渲染二、 島嶼數量三、島嶼的最大面積四、被圍繞的區域 結語 前言 什么是BFS&#xff1f; BFS&#xff08;Breadth - First Search&#xff09;算法&#xff0c;即廣度優先搜索算法&#xff0c;是一種用于圖或樹結構的遍歷算法。以下是其詳細介紹&am…

Linux安裝MySQL數據庫并使用C語言進行數據庫開發

目錄 一、前言 二、安裝VMware運行Ubuntu 1.安裝VMware 2.使用VMware打開Ubuntu 三、配置VMware使用網卡 1.添加NAT網卡 四、Linux下安裝MySQL數據庫 五、安裝MySQL開發庫 六、演示代碼 sql_connect.c sql_connect.h main.c中數據庫相關代碼 結尾 一、前言 由于最…

ROS2 部署大語言模型節點

4GB GPU的DeepSeek-Coder 1.3B模型&#xff0c;并且它已經被量化或優化過。以下是具體的步驟&#xff1a; 安裝必要的依賴項&#xff1a; pip install transformers torch grpcio googleapis-common-protos創建一個新的ROS 2包&#xff1a; cd ~/ros2_ws/src ros2 pkg creat…

本人設計的最完全的光壓發電機模型

雙螺旋轉子光壓發電機結構模型 作者&#xff1a;龔仕成 單位&#xff1a;四川水利職業技術學院電力工程系 日期&#xff1a;2024年3月25日 摘要 本文提出了一種基于梯形螺旋溝槽多層復合材料轉子的光壓發電機結構模型&#xff0c;通過光-機-電協同設計實現高效能量轉換。通…