【PyTorch知識點匯總】

PyTorch是一個廣泛使用的深度學習框架,它提供了許多功能強大的工具和函數,用于構建和訓練神經網絡。以下是一些PyTorch的常用知識點和示例說明:

  1. 張量(Tensors)

    • 創建張量:使用torch.tensor()?、torch.Tensor()?或特定創建函數如torch.zeros()?, torch.ones()?, torch.randn()?等創建不同類型的張量。

      import torch
      x = torch.tensor([1., 2., 3.])  # 創建一個浮點型張量
      zeros_tensor = torch.zeros((3, 4))  # 創建一個3x4的全零張量
      
    • 張量操作:類似NumPy,支持各種數學運算和索引操作,如加減乘除、矩陣乘法、廣播機制、切片等。

      y = torch.tensor([4., 5., 6.])
      result = x + y  # 張量加法
      
    • 數據類型轉換:通過.to()?方法可以改變張量的數據類型或者設備(CPU/GPU)。

      device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
      x_gpu = x.to(device)  # 將張量移動到GPU上
      
  2. 自動微分(Autograd)

    • 使用.requires_grad_()?標記張量以啟用梯度計算:

      x.requires_grad_()
      y = x * 2
      z = y.sum()
      z.backward()  # 自動計算梯度
      print(x.grad)  # 輸出x的梯度
      
  3. 神經網絡模塊(nn.Module)

    • 定義網絡結構:繼承自nn.Module?并實現__init__?和forward?方法。

      import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.linear = nn.Linear(784, 10)  # 定義一個線性層def forward(self, x):out = self.linear(x)return out
      
    • 構建與訓練模型:

      model = SimpleNet()
      criterion = nn.CrossEntropyLoss()
      optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for epoch in range(num_epochs):# 前向傳播output = model(inputs)loss = criterion(output, targets)# 反向傳播及優化optimizer.zero_grad()loss.backward()optimizer.step()
      
  4. 數據加載器(DataLoader)

    • 使用torch.utils.data.DataLoader?來加載和批處理數據。

      from torch.utils.data import DataLoader, TensorDatasetdataset = TensorDataset(data_tensor, label_tensor)
      dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
      for batch_data, batch_labels in dataloader:# 在每個迭代周期中,batch_data和batch_labels會是當前批次的張量數據pass
      
  5. 保存與加載模型

    • 使用torch.save()?和torch.load()?保存和加載模型參數或整個模型。

      torch.save(model.state_dict(), 'model.pth')  # 保存模型參數
      model.load_state_dict(torch.load('model.pth'))  # 加載模型參數# 或者保存整個模型
      torch.save(model, 'model_full.pth')  # 保存整個模型(包括其結構和參數)
      loaded_model = torch.load('model_full.pth', map_location=device)  # 加載整個模型
      
  6. 多GPU并行訓練

    • 使用torch.nn.DataParallel?或torch.nn.parallel.DistributedDataParallel?進行多GPU訓練。

      model = nn.DataParallel(SimpleNet())  # 如果有多塊GPU可用,則將模型分布到多個GPU上
      
  7. 控制流(autograd with control flow)

    • PyTorch支持在動態圖模式下使用Python的控制流語句(如if-else、for循環),并且能正確跟蹤梯度。

動態計算圖、混合精度訓練、量化壓縮、可視化工具

動態計算圖(Dynamic Computation Graph)
在PyTorch中,計算圖是在運行時構建的,這意味著你可以根據程序運行的狀態實時改變網絡結構或執行不同的計算路徑。這是與靜態計算圖框架如TensorFlow的一個顯著區別。

示例:

# 動態改變模型結構
class DynamicModel(nn.Module):def __init__(self):super(DynamicModel, self).__init__()self.linear1 = nn.Linear(10, 5)self.linear2 = nn.Linear(5, 3)def forward(self, x, use_second_layer=True):out = F.relu(self.linear1(x))if use_second_layer:out = self.linear2(out)  # 根據條件決定是否使用第二層return outmodel = DynamicModel()

混合精度訓練(Mixed Precision Training)
混合精度訓練利用了FP16和FP32數據類型的優勢,通過將部分計算轉移到半精度上以減少內存占用和加快計算速度,同時保持關鍵部分(如梯度更新)在全精度下進行,以維持數值穩定性。

使用torch.cuda.amp?模塊實現自動混合精度訓練:

import torch
from torch.cuda.amp import autocast, GradScalerscaler = GradScaler()for inputs, targets in dataloader:inputs = inputs.cuda()targets = targets.cuda()with autocast():outputs = model(inputs)loss = criterion(outputs, targets)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()optimizer.zero_grad()

量化壓縮(Quantization)
量化是將模型的權重和激活從浮點數轉換為低比特整數的過程,從而減小模型大小并加速推理。PyTorch提供了量化API來實現這一過程。

簡化版量化示例:

import torch.quantization# 假設model是一個已經訓練好的模型
model_fp32 = ...  # 初始化并訓練模型# 首先對模型進行偽量化(模擬量化)
prepared_model = torch.quantization.prepare(model_fp32)
# 進行量化校準(收集統計數據)
quantized_model = torch.quantization.convert(prepared_model)# 現在quantized_model是一個量化后的模型,可以用于推理

可視化工具
PyTorch支持通過torchviz?庫來進行計算圖可視化,或者配合其他工具(如TensorBoard)展示模型結構、訓練指標等。

對于簡單的計算圖可視化:

from torchviz import make_dotx = torch.randn(5, requires_grad=True)
y = x * 2
z = y ** 2
z.backward(torch.ones_like(z))dot_graph = make_dot(z)
dot_graph.view()  # 在Jupyter Notebook中顯示圖形

對于模型結構可視化,通常結合torchsummary?或直接使用TensorBoard配合torch.utils.tensorboard?接口:

from torchsummary import summarysummary(model, input_size=(1, 28, 28))  # 對于卷積神經網絡,輸入維度為(通道, 高, 寬)# 或者在TensorBoard中展示模型結構
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter()
writer.add_graph(model, torch.rand((1, 28, 28)))  # 輸入一個隨機張量獲取模型結構
writer.close()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/711304.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/711304.shtml
英文地址,請注明出處:http://en.pswp.cn/news/711304.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

面試經典150題——用最少數量的箭引爆氣球

"The only person you are destined to become is the person you decide to be." - Ralph Waldo Emerson 1. 題目描述 2. 題目分析與解析 這個題目開始讀題的時候是有點不好理解題意的,因此我先做個圖讓大家對于題意有更好更直觀的理解再來分析題目。 …

如何使用Portainer創建Nginx容器并搭建web網站發布至公網可訪問【內網穿透】

文章目錄 前言1. 安裝Portainer1.1 訪問Portainer Web界面 2. 使用Portainer創建Nginx容器3. 將Web靜態站點實現公網訪問4. 配置Web站點公網訪問地址4.1公網訪問Web站點 5. 固定Web靜態站點公網地址6. 固定公網地址訪問Web靜態站點 前言 Portainer是一個開源的Docker輕量級可視…

SQL 常見命令及規范

常見命令 1. 查看當前所有數據庫 show databases; 2. 打開指定的庫 use 庫名 ; 3. 查看當前庫的所有表 show tables; 4. 查看其他庫的所有表 show tables from 庫名 ; 5. 創建表 cerate table 表名 ( 列名 列類型, 列名 列類型, ..... …

基于YOLO家族最新模型YOLOv9開發構建自己的個性化目標檢測系統從零構建模型完整訓練、推理計算超詳細教程【以自建數據酸棗病蟲害檢測為例】

在我前面的系列博文中,對于目標檢測系列的任務寫了很多超詳細的教程,目的是能夠讀完文章即可實現自己完整地去開發構建自己的目標檢測系統,感興趣的話可以自行移步閱讀: 《基于官方YOLOv4-u5【yolov5風格實現】開發構建目標檢測模型超詳細實戰教程【以自建缺陷檢測數據集為…

C# OpenVINO Crack Seg 裂縫分割 裂縫檢測

目錄 效果 模型信息 項目 代碼 數據集 下載 C# OpenVINO Crack Seg 裂縫分割 裂縫檢測 效果 模型信息 Model Properties ------------------------- date:2024-02-29T16:35:48.364242 author:Ultralytics task:segment version&…

去掉WordPress網頁圖片默認鏈接功能

既然是wordpress自動添加的,那么我們在上傳圖片到wordpress后臺多媒體的時候,就可以手動改變鏈接指向或者刪除掉,問題是每次都要這么做很麻煩,更別說有忘記的時候。一次性解決這個問題有兩種方法,一種是No Image Link插…

【生成式AI】ChatGPT原理解析(1/3)- 對ChatGPT的常見誤解

Hung-yi Lee 課件整理 文章目錄 誤解1誤解2ChatGPT真正在做的事情-文字接龍 ChatGPT是在2022年12月7日上線的。 當時試用的感覺十分震撼。 誤解1 我們想讓chatGPT講個笑話,可能會以為它是在一個笑話的集合里面隨機地找一個笑話出來。 我們做一個測試就知道不是這樣…

C# Post數據或文件到指定的服務器進行接收

目錄 應用場景 實現原理 實現代碼 PostAnyWhere類 ashx文件部署 小結 應用場景 不同的接口服務器處理不同的應用,我們會在實際應用中將A服務器的數據提交給B服務器進行數據接收并處理業務。 比如我們想要處理一個OFFICE文件,由用戶上傳到A服務器…

中國汽車電子行業發展現狀分析及投資前景預測報告

全版價格:壹捌零零 報告版本:下單后會更新至最新版本 交貨時間:1-2天 第一章 汽車電子相關概述 1.1 汽車的相關介紹 1.1.1 汽車的概念 我國國家最新標準《汽車和掛車類型的術語和定義》(GB/T3730.1—2001&…

基于springboot+vue的貿易行業crm系統

博主主頁:貓頭鷹源碼 博主簡介:Java領域優質創作者、CSDN博客專家、阿里云專家博主、公司架構師、全網粉絲5萬、專注Java技術領域和畢業設計項目實戰,歡迎高校老師\講師\同行交流合作 ?主要內容:畢業設計(Javaweb項目|小程序|Pyt…

Flink分區相關

0、要點 Flink的分區列不會存數據,也就是兩個列有一個分區列,則文件只會存另一個列的數據 1、CreateTable 根據SQL的執行流程,進入TableEnvironmentImpl.executeInternal,createTable分支 } else if (operation instanceof Crea…

Java-nio

一、NIO三大組件 NIO的三大組件分別是Channel,Buffer與Selector Java NIO系統的核心在于:通道(Channel)和緩沖區(Buffer)。通道表示打開到 IO 設備(例如:文件、套接字)的連接。若需要使用 NIO 系統,需要獲取用于連接 IO 設備的通…

Spring的簡單使用及內部實現原理

在現代的Java應用程序開發中,Spring Framework已經成為了不可或缺的工具之一。它提供了一種輕量級的、基于Java的解決方案,用于構建企業級應用程序和服務。本文將介紹Spring的簡單使用方法,并深入探討其內部實現原理。 首先,讓我們…

mysql8.0使用MGR實現高可用

一、三節點MGR集群的安裝部署 1. 安裝準備 準備好下面三臺服務器&#xff1a; IP端口角色192.168.150.213306mgr1192.168.150.223306mgr2192.168.150.233306mgr3 配置hosts解析 # cat >> /etc/hosts << EOF 192.168.150.21 mgr1 192.168.150.22 mgr2 192.168…

Windows環境下的調試器探究——硬件斷點

與軟件斷點與內存斷點不同&#xff0c;硬件斷點不依賴被調試程序&#xff0c;而是依賴于CPU中的調試寄存器。 調試寄存器有7個&#xff0c;分別為Dr0~Dr7。 用戶最多能夠設置4個硬件斷點&#xff0c;這是由于只有Dr0~Dr3用于存儲線性地址。 其中&#xff0c;Dr4和Dr5是保留的…

java中容器繼承體系

首先上圖 源碼解析 打開Collection接口源碼&#xff0c;能夠看到Collection接口是繼承了Iterable接口。 public interface Collection<E> extends Iterable<E> { /** * ...... */ } 以下是Iterable接口源碼及注釋 /** * Implementing this inte…

makefileGDB使用

一、makefile 1、make && makefile makefile帶來的好處就是——自動化編譯&#xff0c;一旦寫好&#xff0c;只需要一個make命令&#xff0c;整個工程完全自動編譯&#xff0c;極大的提高了軟件開發的效率 下面我們通過如下示例來進一步體會它們的作用&#xff1a; ①…

使用 Python 實現一個飛書/微信記賬機器人,酷B了!

Python飛書文檔機器人 今天的主題是&#xff1a;使用Python聯動飛書文檔機器人&#xff0c;實現一個專屬的記賬助手&#xff0c;這篇文章如果對你幫助極大&#xff0c;歡迎你分享給你的朋友、她、他&#xff0c;一起成長。 也歡迎大家留言&#xff0c;說說自己想看什么主題的…

代碼隨想錄第天 78.子集 90.子集II

LeetCode 78 子集 題目描述 給你一個整數數組 nums &#xff0c;數組中的元素 互不相同 。返回該數組所有可能的子集&#xff08;冪集&#xff09;。 解集 不能 包含重復的子集。你可以按 任意順序 返回解集。 示例 1&#xff1a; 輸入&#xff1a;nums [1,2,3] 輸出&…

LeetCode 2581.統計可能的樹根數目:換根DP(樹形DP)

【LetMeFly】2581.統計可能的樹根數目&#xff1a;換根DP(樹形DP) 力扣題目鏈接&#xff1a;https://leetcode.cn/problems/count-number-of-possible-root-nodes/ Alice 有一棵 n 個節點的樹&#xff0c;節點編號為 0 到 n - 1 。樹用一個長度為 n - 1 的二維整數數組 edges…