多GPU訓練

寫在前面

限于財力不足,本機上只有一個 GPU 可供使用,因此這部分的代碼只能夠稍作了解,能夠使用的 GPU 也只有一個。

多 GPU 的數據并行:有幾張卡,對一個小批量數據,有幾張卡就分成幾塊,每個 GPU 分別計算梯度,然后加起來做并行。

從零開始實現

%matplotlib inline
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

簡單網絡

# 初始化模型參數
scale = 0.01
W1 = torch.randn(size=(20, 1, 3, 3)) * scale
b1 = torch.zeros(20)
W2 = torch.randn(size=(50, 20, 5, 5)) * scale
b2 = torch.zeros(50)
W3 = torch.randn(size=(800, 128)) * scale
b3 = torch.zeros(128)
W4 = torch.randn(size=(128, 10)) * scale
b4 = torch.zeros(10)
params = [W1, b1, W2, b2, W3, b3, W4, b4]# 定義模型
def lenet(X, params):h1_conv = F.conv2d(input=X, weight=params[0], bias=params[1])h1_activation = F.relu(h1_conv)h1 = F.avg_pool2d(input=h1_activation, kernel_size=(2, 2), stride=(2, 2))h2_conv = F.conv2d(input=h1, weight=params[2], bias=params[3])h2_activation = F.relu(h2_conv)h2 = F.avg_pool2d(input=h2_activation, kernel_size=(2, 2), stride=(2, 2))h2 = h2.reshape(h2.shape[0], -1)h3_linear = torch.mm(h2, params[4]) + params[5]h3 = F.relu(h3_linear)y_hat = torch.mm(h3, params[6]) + params[7]return y_hat# 交叉熵損失函數
loss = nn.CrossEntropyLoss(reduction='none')

向多個設備分發參數,并通過將模型參數復制到一個GPU:

def get_params(params, device): # 把一個參數復制到另外一個GPU上去new_params = [p.to(device) for p in params]for p in new_params:p.requires_grad_() #對每一個參數都需要計算梯度return new_paramsnew_params = get_params(params, d2l.try_gpu(0))
print('b1 權重:', new_params[1])
print('b1 梯度:', new_params[1].grad)

在這里插入圖片描述
allreduce函數將所有向量相加,并將結果廣播給所有GPU

def allreduce(data):for i in range(1, len(data)):data[0][:] += data[i].to(data[0].device)for i in range(1, len(data)):data[i][:] = data[0].to(data[i].device)data = [torch.ones((1, 2), device=d2l.try_gpu(i)) * (i + 1) for i in range(2)]
print('allreduce之前:\n', data[0], '\n', data[1])
allreduce(data)
print('allreduce之后:\n', data[0], '\n', data[1])

在這里插入圖片描述
將一個小批量數據均勻地分布在多個 GPU 上

data = torch.arange(20).reshape(4, 5)
devices = [torch.device('cuda:0'), torch.device('cuda:1')]
split = nn.parallel.scatter(data, devices)
print('input :', data)
print('load into', devices)
print('output:', split)

在這里插入圖片描述

#@save
def split_batch(X, y, devices):"""將X和y拆分到多個設備上"""assert X.shape[0] == y.shape[0]return (nn.parallel.scatter(X, devices),nn.parallel.scatter(y, devices))

在一個小批量上實現多GPU訓練

def train_batch(X, y, device_params, devices, lr):X_shards, y_shards = split_batch(X, y, devices)# 在每個GPU上分別計算損失ls = [loss(lenet(X_shard, device_W), y_shard).sum()for X_shard, y_shard, device_W in zip(X_shards, y_shards, device_params)]for l in ls:  # 反向傳播在每個GPU上分別執行l.backward()# 將每個GPU的所有梯度相加,并將其廣播到所有GPUwith torch.no_grad():for i in range(len(device_params[0])):allreduce([device_params[c][i].grad for c in range(len(devices))])# 在每個GPU上分別更新模型參數for param in device_params:d2l.sgd(param, lr, X.shape[0]) # 在這里,我們使用全尺寸的小批量

定義訓練模型:

def train(num_gpus, batch_size, lr):train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)devices = [d2l.try_gpu(i) for i in range(num_gpus)]# 將模型參數復制到num_gpus個GPUdevice_params = [get_params(params, d) for d in devices]num_epochs = 10animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])timer = d2l.Timer()for epoch in range(num_epochs):timer.start()for X, y in train_iter:# 為單個小批量執行多GPU訓練train_batch(X, y, device_params, devices, lr)torch.cuda.synchronize()timer.stop()# 在GPU0上評估模型animator.add(epoch + 1, (d2l.evaluate_accuracy_gpu(lambda x: lenet(x, device_params[0]), test_iter, devices[0]),))print(f'測試精度:{animator.Y[0][-1]:.2f}{timer.avg():.1f}秒/輪,'f'在{str(devices)}')

在單個 GPU 上運行:
在這里插入圖片描述
增加為 2 個 GPU
在這里插入圖片描述
并行后并沒有變快,可能有以下原因:

  • Data 讀取比較慢
  • GPU 增加了,但是 batch_size 沒有增加

多 GPU 的簡潔實現

import torch
from torch import nn
from d2l import torch as d2l

簡單網絡

#@save
def resnet18(num_classes, in_channels=1):"""稍加修改的ResNet-18模型"""def resnet_block(in_channels, out_channels, num_residuals,first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(d2l.Residual(in_channels, out_channels,use_1x1conv=True, strides=2))else:blk.append(d2l.Residual(out_channels, out_channels))return nn.Sequential(*blk)# 該模型使用了更小的卷積核、步長和填充,而且刪除了最大匯聚層net = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU())net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))net.add_module("resnet_block2", resnet_block(64, 128, 2))net.add_module("resnet_block3", resnet_block(128, 256, 2))net.add_module("resnet_block4", resnet_block(256, 512, 2))net.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1,1)))net.add_module("fc", nn.Sequential(nn.Flatten(),nn.Linear(512, num_classes)))return netnet = resnet18(10)
# 獲取GPU列表
devices = d2l.try_all_gpus()
# 我們將在訓練代碼實現中初始化網絡

訓練

def train(net, num_gpus, batch_size, lr):train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)devices = [d2l.try_gpu(i) for i in range(num_gpus)]def init_weights(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)# 在多個GPU上設置模型net = nn.DataParallel(net, device_ids=devices)trainer = torch.optim.SGD(net.parameters(), lr)loss = nn.CrossEntropyLoss()timer, num_epochs = d2l.Timer(), 10animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])for epoch in range(num_epochs):net.train()timer.start()for X, y in train_iter:trainer.zero_grad()X, y = X.to(devices[0]), y.to(devices[0])l = loss(net(X), y)l.backward()trainer.step()timer.stop()animator.add(epoch + 1, (d2l.evaluate_accuracy_gpu(net, test_iter),))print(f'測試精度:{animator.Y[0][-1]:.2f}{timer.avg():.1f}秒/輪,'f'在{str(devices)}')

在單個 GPU 上訓練網絡

train(net, num_gpus=1, batch_size=256, lr=0.1)

在這里插入圖片描述
使用2個GPU進行訓練

train(net, num_gpus=2, batch_size=512, lr=0.2)

在這里插入圖片描述

QA 思考

Q1:驗證集準確率震蕩較大是哪個參數影響最大呢?
A1:lr

Q2:為什么batch_size調的比較小,比如8,精度會一直在0.1左右,一直不怎么變化
A2:因為batch_size調的比較小的時候,lr 不能太大。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/75697.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/75697.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/75697.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

0基礎 | 硬件 | 電源系統 一

降壓電路LDO 幾乎所有LDO都是基于此拓撲結構 圖 拓撲結構 LDO屬于線性電源,通過控制開關管的導通程度實現穩壓,輸出紋波小,無開關噪聲 線性電源,IoutIin,發熱功率P電壓差△U*電流I,轉換效率Vo/Vi LDO不適…

mysql數據庫中getshell的方式總結

mysql數據庫中getshell的方式總結 MySQL版本大于5.0,MySQL 5.0版本以上會創建日志文件,我們通過修改日志文件的全局變量,就可以GetSHELL,下面這篇文章主要給大家介紹了關于mysql數據庫中getshell的方式,需要的朋友可以參考下 outfile和dumpfile寫shell 利用條件 …

基于Python的微博數據采集

摘要 本系統通過逆向工程微博移動端API接口,實現了對熱門板塊微博內容及用戶評論的自動化采集。系統采用Requests+多線程架構,支持遞歸分頁采集和動態請求頭模擬,每小時可處理3000+條數據記錄。關鍵技術特征包括:1)基于max_id的評論分頁遞歸算法 2)HTML標簽清洗正則表達…

WiFi加密協議

目錄 1. 認證(Authentication)? ?1.1 開放系統認證(Open System Authentication)? 1.2 共享密鑰認證(Shared Key Authentication)? ?1.3 802.1X/EAP認證(企業級認證)? ?2. 關聯(Association)? ?3. 加密協議(Security Handshake)? ?整體流程總結?…

MySQL篇(六)MySQL 分庫分表:應對數據增長挑戰的有效策略

MySQL篇(六)MySQL 分庫分表:應對數據增長挑戰的有效策略 MySQL篇(六)MySQL 分庫分表:應對數據增長挑戰的有效策略一、引言二、為什么需要分庫分表2.1 性能瓶頸2.2 存儲瓶頸2.3 高并發壓力 三、分庫分表的方…

極限編程(XP)簡介及其價值觀與最佳實踐

目錄 一、什么是極限編程(XP)二、極限編程的核心價值觀1. 溝通2. 簡單3. 反饋4. 勇氣 三、極限編程的12個最佳實踐1. 結對編程2. 40小時工作制3. 簡單設計4. 代碼規范5. 測試驅動開發(TDD)6. 系統隱喻7. 持續集成8. 重構9. 客戶在…

Java進階-day06:反射、注解與動態代理深度解析

目錄 一、反射機制:Java的自我認知能力 1.1 認識反射 1.2 獲取Class對象 1.3 獲取類的成分 二、注解:Java的元數據機制 2.1 注解概述 2.2 元注解 2.3 注解解析 2.4 注解的實際應用 三、動態代理:靈活的間接訪問機制 3.1 為什么需要…

Nacos注冊中心AP模式核心源碼分析(集群模式)

文章目錄 概述一、客戶端新注冊實例信息在集群間同步二、服務端集群節點信息在集群間同步2.1、DistroMapper2.2、ProtocolManager2.3、ServerListManager2.4、RaftPeerSet 三、客戶端實例狀態信息在集群間同步四、服務端新節點上線同步集群數據 概述 在Nacos集群模式下&#xf…

vscode和cursor對ubuntu22.04的remote ssh和X-Windows的無密碼登錄

這里寫自定義目錄標題 寫在前面需求的描述問題的引出 昨天已使能自動登錄上午我的改變UBUNTU 22.04關閉密碼規則一:修改 /etc/pam.d/common-password 文件二:修改 /etc/security/pwquality.conf 文件方法三:禁用 pam_pwquality.so 模塊 vscod…

論文閱讀:基于增強通用深度圖像水印的混合篡改定位技術 OmniGuard

一、論文信息 論文名稱:OmniGuard: Hybrid Manipulation Localization via Augmented Versatile Deep Image Watermarking作者團隊:北京大學發表會議:CVPR2025論文鏈接:https://arxiv.org/pdf/2412.01615二、動機與貢獻 動機: 隨著生成式 AI 的快速發展,其在圖像編輯領…

一周學會Pandas2 Python數據處理與分析-NumPy數組創建

鋒哥原創的Pandas2 Python數據處理與分析 視頻教程: 2025版 Pandas2 Python數據處理與分析 視頻教程(無廢話版) 玩命更新中~_嗶哩嗶哩_bilibili NumPy數組創建最常用的方式是直接創建, numpy 可以直接創建或者將 python的其他元素轉為 array 對象。 下…

【全球首發】DeepSeek谷歌版1.1.5 - 免費GPT-4級別AI工具

【全球首發】DeepSeek谷歌版1.1.5 - 免費GPT-4級別AI工具 資源簡介 DeepSeek谷歌版1.1.5是目前全球領先的免費AI助手,性能超越國內主流AI產品,提供類似GPT-4的智能體驗。 版本信息 最新版本:1.1.5(2024最新版)應用…

小程序29-事件穿參-mark 自定義數據

小程序進行事件傳參的時候,除了使用 data-*屬性 傳遞參數外,還可以 使用 mark 標記傳遞參數 mark 是一種自定義屬性,可以在組件上添加,用于來識別具體觸發事件的 target 節點。同時 mark 還可以用于承載一些自定義數據 在組件上使…

高級:分布式系統面試題精講

一、引言 分布式系統在現代軟件開發中占據重要地位,其設計和實現需要考慮多個關鍵因素。面試官通過相關問題,考察候選人對分布式系統核心概念的理解、實際應用能力以及在復雜場景下的問題解決能力。本文將深入分析分布式系統的CAP定理、一致性協議、分布…

【Android Studio 下載 Gradle 失敗】

路雖遠行則將至,事雖難做則必成 一、事故現場 下載Gradle下載不下來,沒有Gradle就無法把項目編譯為Android應用。 二、問題分析 觀察發現下載時長三分鐘,進度條半天沒動,說明這個是國外的東西,被墻住了,需…

系統思考:思考的快與慢

在做重大決策之前,什么原因一定要補充碳水化合物?人類的大腦其實有兩套運作模式:系統1:自動駕駛模式,依賴直覺,反應快但易出錯;系統2:手動駕駛模式,理性嚴謹,…

從情感分析到樸素貝葉斯法:基于樸素貝葉斯的情感分析如何讓DeepSeek賦能你的工作?

文章目錄 1.概率論基礎1.1 單事件概率1.2 多事件概率1.3 條件概率1.3.1 多事件概率與條件概率的區別 1.4 貝葉斯定理傳統思維誤區貝葉斯定理計算 2. 樸素貝葉斯法2.1 基本概念2.2 模型2.3 學習策略2.4 優化算法2.5 優化技巧拉普拉斯平滑對數似然 3. 情感分析實戰3.1 流程3.2 模…

獲取inode的完整路徑包含掛載的路徑

一、背景 在之前的博客 缺頁異常導致的iowait打印出相關文件的絕對路徑-CSDN博客 里的 2.2.3 一節和 關于inode,dentry結合軟鏈接及硬鏈接的實驗-CSDN博客 里,我們講到了在內核里通過inode獲取inode對應的絕對路徑的方法。對于根目錄下的文件而言&#…

【51單片機】2-6【I/O口】【電動車簡易防盜報警器實現】

1.硬件 51最小系統繼電器模塊震動傳感器模塊433M無線收發模塊 2.軟件 #include "reg52.h" #include<intrins.h> #define J_ON 1 #define J_OFF 0sbit switcher P1^0;//繼電器 sbit D0_ON P1^1;//433M無線收發模塊 sbit D1_OFF P1^2; sbit vibrate …

leetcode二叉樹刷題調試不方便的解決辦法

1. 二叉樹不易構建 在leetcode中刷題時&#xff0c;如果沒有會員就需要將代碼拷貝到本地的編譯器進行調試。但是leetcode中有一類題可謂是毒瘤&#xff0c;那就是二叉樹的題。 要調試二叉樹有關的題需要根據測試用例給出的前序遍歷&#xff0c;自己構建一個二叉樹&#xff0c;…