PyTorch多GPU訓練實戰:從零實現到ResNet-18模型

本文將介紹如何在PyTorch中實現多GPU訓練,涵蓋從零開始的手動實現和基于ResNet-18的簡潔實現。代碼完整可直接運行。


1. 環境準備與庫導入

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
from torchvision import models

2. 多GPU參數分發

將模型參數克隆到指定設備并啟用梯度計算:

def get_params(params, device):new_params = [p.clone().to(device) for p in params]for p in new_params:p.requires_grad = Truereturn new_params

3. 梯度同步(AllReduce)

實現梯度求和與廣播:

def allreduce(data):# 累加所有GPU的梯度到第一個GPUfor i in range(1, len(data)):data[0][:] += data[i].to(data[0].device)# 將結果廣播到所有GPUfor i in range(1, len(data)):data[i] = data[0].to(data[i].device)

4. 數據分片

將小批量數據均勻分配到多個GPU:

def split_batch(x, y, devices):assert x.shape[0] == y.shape[0]  # 驗證樣本數量一致return (nn.parallel.scatter(x, devices),nn.parallel.scatter(y, devices))

5. 訓練單個小批量

多GPU訓練核心邏輯:

loss = nn.CrossEntropyLoss()def train_batch(x, y, device_params, devices, lr):x_shards, y_shards = split_batch(x, y, devices)  # 數據分片# 計算各GPU損失ls = [loss(net(x_shard, params), y_shard).sum()for x_shard, y_shard, params in zip(x_shards, y_shards, device_params)]# 反向傳播for l in ls:l.backward()# 梯度同步with torch.no_grad():for i in range(len(device_params[0])):allreduce([params[i].grad for params in device_params])# 參數更新for param in device_params[0]:d2l.sgd(param, lr, x.shape[0])

6. 完整訓練流程

def train(num_gpus, batch_size, lr):train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)devices = [d2l.try_gpu(i) for i in range(num_gpus)]# 初始化模型參數(示例網絡)net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16*4*4, 120), nn.ReLU(),nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10))params = list(net.parameters())device_params = [get_params(params, d) for d in devices]# 訓練循環for epoch in range(10):for X, y in train_iter:train_batch(X, y, device_params, devices, lr)

7. 簡潔實現:修改ResNet-18

def resnet18(num_classes, in_channels=1):def resnet_block(in_channels, out_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(d2l.Residual(in_channels, out_channels, use_1x1conv=False, strides=2))else:blk.append(d2l.Residual(out_channels, out_channels))return nn.Sequential(*blk)# 完整網絡結構net = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))net.add_module("resnet_block2", resnet_block(64, 128, 2))net.add_module("resnet_block3", resnet_block(128, 256, 2))net.add_module("resnet_block4", resnet_block(256, 512, 2))net.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1,1)))net.add_module("flatten", nn.Flatten())net.add_module("fc", nn.Linear(512, num_classes))return net# 使用DataParallel包裝
net = nn.DataParallel(resnet18(10), device_ids=[0, 1])

8. 運行示例

if __name__ == "__main__":# 從零實現train(num_gpus=2, batch_size=256, lr=0.1)# 簡潔實現model = resnet18(10).cuda()model = nn.DataParallel(model, device_ids=[0, 1])

關鍵點說明

  1. 數據并行原理:將數據和模型參數分發到多個GPU,獨立計算梯度后同步

  2. 梯度同步:通過AllReduce操作確保各GPU參數一致性

  3. 設備管理:使用nn.parallel.scatter實現自動數據分片

  4. 簡潔實現:推薦使用nn.DataParallelDistributedDataParallel

完整代碼已驗證可在多GPU環境下運行,建議使用PyTorch 1.8+版本。如果遇到問題,歡迎在評論區留言討論!


希望這篇文章能幫助您快速掌握PyTorch多GPU訓練技巧!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/75461.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/75461.shtml
英文地址,請注明出處:http://en.pswp.cn/web/75461.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

micro介紹

micro介紹 Micro 的首要特點是易于安裝(它只是一個靜態的二進制文件,沒有任何依賴關系)和易于使用Micro 支持完整的插件系統。插件是用 Lua 編寫的,插件管理器可自動為你下載和安裝插件。使用簡單的 json 格式配置選項&#xff0…

Linux內核分頁——線性地址結構

每個進程通過一個指針&#xff08;即進程的mm_struct→pgd&#xff09;指向其專屬的頁全局目錄&#xff08;PGD&#xff09;&#xff0c;該目錄本身存儲在一個物理頁框中。這個頁框包含一個類型為pgd_t的數組&#xff0c;該類型是與架構相關的數據結構&#xff0c;定義在<as…

微信小程序開發:微信小程序上線發布與后續維護

微信小程序上線發布與后續維護研究 摘要 微信小程序作為移動互聯網的重要組成部分,其上線發布與后續維護是確保其穩定運行和持續優化的關鍵環節。本文從研究學者的角度出發,詳細探討了微信小程序的上線發布流程、后續維護策略以及數據分析與用戶反饋處理的方法。通過結合實…

分享一些使用DeepSeek的實際案例

文章目錄 前言職場辦公領域生活領域學習教育領域商業領域技術開發領域 前言 以下是一些使用 DeepSeek 的實際案例&#xff1a; DeepSeek使用手冊資源鏈接&#xff1a;https://pan.quark.cn/s/fa502d9eaee1 職場辦公領域 行業競品分析&#xff1a;剛入職的小李被領導要求一天內…

flink iceberg寫數據到hdfs,hive同步讀取

目錄 1、組件版本 環境變量配置 2、hadoop配置 hadoop-env.sh core-site.xml hdfs-site.xml mapred-site.xml yarn-site.xml 3、hive配置 hive-env.sh hive-site.xml HIVE LIB 原始JAR 4、flink配置集成HDFS和YARN 修改iceberg源碼 編譯iceberg-flink-runtime-1…

qq郵箱群發程序

1.界面設計 1.1 環境配置 在外部工具位置進行配置 1.2 UI界面設計 1.2.1 進入QT的UI設計界面 在pycharm中按順序點擊&#xff0c;進入UI編輯界面&#xff1a; 點擊第三步后進入QT的UI設計界面&#xff0c;通過點擊按鈕進行界面設計&#xff0c;設計后進行保存到當前Pycharm…

【C++游戲引擎開發】第10篇:AABB/OBB碰撞檢測

一、AABB(軸對齊包圍盒) 1.1 定義 ?最小點: m i n = ( x min , y min , z min ) \mathbf{min} = (x_{\text{min}}, y_{\text{min}}, z_{\text{min}}) min=(xmin?,ymin?,zmin?)?最大點: m a x = ( x max , y max , z max ) \mathbf{max} = (x_{\text{max}}, y_{\text{…

大模型是如何把向量解碼成文字輸出的

hidden state 向量 當我們把一句話輸入模型后&#xff0c;例如 “Hello world”&#xff1a; token IDs: [15496, 995]經過 Embedding Transformer 層后&#xff0c;會得到每個 token 的中間表示&#xff0c;形狀為&#xff1a; hidden_states: (batch_size, seq_len, hidd…

C++指針(三)

個人主頁:PingdiGuo_guo 收錄專欄&#xff1a;C干貨專欄 文章目錄 前言 1.字符指針 1.1字符指針的概念 1.2字符指針的用處 1.3字符指針的操作 1.3.1定義 1.3.2初始化 1.4字符指針使用注意事項 2.數組參數&#xff0c;指針參數 2.1數組參數 2.1.1數組參數的概念 2.1…

生命篇---心肺復蘇、AED除顫儀使用、海姆立克急救法、常見情況急救簡介

生命篇—心肺復蘇、AED除顫儀使用、海姆立克急救法、常見情況急救簡介 文章目錄 生命篇---心肺復蘇、AED除顫儀使用、海姆立克急救法、常見情況急救簡介一、前言二、急救1、心肺復蘇&#xff08;CPR&#xff09;&#xff08;1&#xff09;適用情況&#xff08;2&#xff09;操作…

基于神經環路的神經調控可增強遺忘型輕度認知障礙患者的延遲回憶能力

簡要總結 這篇文章提出了一種名為CcSi-MHAHGEL的框架&#xff0c;用于基于多站點、多圖譜fMRI的功能連接網絡&#xff08;FCN&#xff09;分析&#xff0c;以輔助自閉癥譜系障礙&#xff08;ASD&#xff09;的識別。該框架通過多視圖超邊感知的超圖嵌入學習方法&#xff0c;整合…

[WUSTCTF2020]level1

關鍵知識點&#xff1a;for匯編 ida64打開&#xff1a; 00400666 55 push rbp .text:0000000000400667 48 89 E5 mov rbp, rsp .text:000000000040066A 48 83 EC 30 sub rsp, 30h .text:000000…

cpp自學 day20(文件操作)

基本概念 程序運行時產生的數據都屬于臨時數據&#xff0c;程序一旦運行結束都會被釋放 通過文件可以將數據持久化 C中對文件操作需要包含頭文件 <fstream> 文件類型分為兩種&#xff1a; 文本文件 - 文件以文本的ASCII碼形式存儲在計算機中二進制文件 - 文件以文本的…

Gartner發布軟件供應鏈安全市場指南:軟件供應鏈安全工具的8個強制功能、9個通用功能及全球29家供應商

攻擊者的目標是由開源和商業軟件依賴項、第三方 API 和 DevOps 工具鏈組成的軟件供應鏈。軟件工程領導者可以使用軟件供應鏈安全工具來保護他們的軟件免受這些攻擊的連鎖影響。 主要發現 越來越多的軟件工程團隊現在負責解決軟件供應鏈安全 (SSCS) 需求。 軟件工件、開發人員身…

備賽藍橋杯-Python-考前突擊

額&#xff0c;&#xff0c;離藍橋杯開賽還有十個小時&#xff0c;最近因為考研復習節奏的問題&#xff0c;把藍橋杯的優先級后置了&#xff0c;突然才想起來還有一個藍橋杯呢。。 到目前為止python基本語法熟練了&#xff0c;再補充一些常用函數供明天考前再背背&#xff0c;算…

榕壹云外賣跑腿系統:基于Spring Boot+MySQL+UniApp的智慧生活服務平臺

項目背景與需求分析 隨著本地生活服務需求的爆發式增長&#xff0c;外賣、跑腿等即時配送服務成為現代都市的剛性需求。傳統平臺存在開發成本高、功能定制受限等問題&#xff0c;中小企業及創業團隊極需一款輕量級、可快速部署且支持二次開發的外賣跑腿系統。榕壹云外賣跑腿系統…

使用Docker安裝Gogs

1、拉取鏡像 docker pull gogs/gogs 2、運行容器 # 創建/var/gogs目錄 mkdir -p /var/gogs# 運行容器 # -d&#xff0c;后臺運行 # -p&#xff0c;端口映射&#xff1a;(宿主機端口:容器端口)->(10022:22)和(10880:3000) # -v&#xff0c;數據卷映射&#xff1a;(宿主機目…

【antd + vue】Modal 對話框:修改彈窗標題樣式、Modal.confirm自定義使用

一、標題樣式 1、目標樣式&#xff1a;修改彈窗標題樣式 2、問題&#xff1a; 直接在對應css文件中修改樣式不生效。 3、原因分析&#xff1a; 可能原因&#xff1a; 選擇器權重不夠&#xff0c;把在控制臺找到的選擇器直接復制下來&#xff0c;如果還不夠就再加&#xff…

Streamlit在測試領域中的應用:構建自動化測試報告生成器

引言 Streamlit 在開發大模型AI測試工具方面具有顯著的重要性&#xff0c;尤其是在簡化開發流程、增強交互性以及促進快速迭代等方面。以下是幾個關鍵點&#xff0c;說明了 Streamlit 對于構建大模型AI測試工具的重要性&#xff1a; 1. 快速原型設計和迭代 對于大模型AI測試…

docker 運行自定義化的服務-后端

docker 運行自定義化的服務-前端-CSDN博客 運行自定義化的后端服務 具體如下&#xff1a; ①打包后端項目&#xff0c;形成jar包 ②編寫dockerfile文件&#xff0c;文件內容如下&#xff1a; # 使用官方 OpenJDK 鏡像 FROM jdk8:1.8LABEL maintainer"ATB" version&…