將數據切分成N份,采用NCCL異步通信,讓all_gather+matmul盡量Overlap

將數據切分成N份,采用NCCL異步通信,讓all_gather+matmul盡量Overlap

  • 一.測試數據
  • 二.測試環境
  • 三.普通實現
  • 四.分塊實現

本文演示了如何將數據切分成N份,采用NCCL異步通信,讓all_gather+matmul盡量Overlap

一.測試數據

  • 1.測試規模:8192*8192 world_size=2
  • 2.單算子:all_gather:0.03508s matmul:0.05689s e2e:0.09197s。matmul耗時最長
  • 3.按輸入和權值切分成8份,async_op=True。e2e:0.75ms
  • 4.e2e耗時從91ms縮短到75ms 縮短了17%。耗時為純matmul算子的:1.34倍

二.測試環境

docker run --gpus all --shm-size=32g -ti -e NVIDIA_VISIBLE_DEVICES=all \--privileged --net=host -v $PWD:/home \-w /home --name all_gather_mm \nvcr.io/nvidia/pytorch:23.07-py3 /bin/bash

三.普通實現

tee all_gather_mm_native.py <<-'EOF'
import os
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
import time
import numpy as np
from torch.profiler import profile
import nvtxdev_type="cuda"
dist.init_process_group(backend='nccl')torch.manual_seed(1)
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
local_rank=int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
device = torch.device(dev_type,local_rank)
shape=(8192,8192)input_tensor=torch.rand((shape[0],shape[1]),dtype=torch.float).to(device)
weight=torch.rand((shape[1],8192),dtype=torch.float).to(device)
all_gather_buffer=torch.zeros((shape[0]*world_size,shape[1]),dtype=torch.float).to(device)for i in range(10):with nvtx.annotate(f"iter:{i}", color="blue"): dist.barrier()t0=time.time()torch.distributed._all_gather_base(all_gather_buffer, input_tensor)dist.barrier()torch.cuda.synchronize()t1=time.time()output = torch.matmul(all_gather_buffer, weight)torch.cuda.synchronize()t2=time.time()if rank==0:print(f"iter:{i} all_gather:{t1-t0:.5f} matmul:{t2-t1:.5f} e2e:{t2-t0:.5f} data:{output.mean()}")
EOF
export NCCL_DEBUG=error
export NCCL_IB_DISABLE=1
export CUDA_VISIBLE_DEVICES="1,3"
torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_nativensys profile --stats=true -o all_gather_mm_native.nsys-rep -f true -t cuda,nvtx --gpu-metrics-device=1,3 \torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_native

輸出

iter:0 all_gather:0.03809 matmul:0.84971 e2e:0.88780 data:2047.62548828125
iter:1 all_gather:0.03327 matmul:0.06595 e2e:0.09922 data:2047.62548828125
iter:2 all_gather:0.03720 matmul:0.06082 e2e:0.09802 data:2047.62548828125
iter:3 all_gather:0.03682 matmul:0.05644 e2e:0.09326 data:2047.62548828125
iter:4 all_gather:0.03382 matmul:0.05648 e2e:0.09030 data:2047.62548828125
iter:5 all_gather:0.03404 matmul:0.05635 e2e:0.09039 data:2047.62548828125
iter:6 all_gather:0.03657 matmul:0.05701 e2e:0.09359 data:2047.62548828125
iter:7 all_gather:0.03840 matmul:0.05695 e2e:0.09535 data:2047.62548828125
iter:8 all_gather:0.03721 matmul:0.05685 e2e:0.09406 data:2047.62548828125
iter:9 all_gather:0.03508 matmul:0.05689 e2e:0.09197 data:2047.62548828125

在這里插入圖片描述

四.分塊實現

tee all_gather_mm_tiling.py <<-'EOF'
import os
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
import time
import numpy as np
import nvtx# 分幾塊
num_blocks = 8dev_type="cuda"
dist.init_process_group(backend='nccl')torch.manual_seed(1)
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
local_rank=int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
device = torch.device(dev_type,local_rank)streams = [torch.cuda.Stream(device=device) for _ in range(num_blocks)]def all_gather_matmul(rank, world_size, input, weight,gathered_buffer,output_buffer, num_blocks, device):input_chunk_size = input.size(0) // num_blocks  # 每塊的大小weight_chunk_size = weight.size(1) // num_blockshandles = []for i in range(num_blocks):with torch.cuda.stream(streams[i]):# 劃分塊并進行 all_gatherinput_chunk = input[i * input_chunk_size: (i + 1) * input_chunk_size]gather_start_idx = i * input_chunk_size * world_size  # 起始索引handle = dist.all_gather_into_tensor(gathered_buffer[gather_start_idx:gather_start_idx + input_chunk_size * world_size], input_chunk, async_op=True)handles.append((handle, gather_start_idx))outputs = torch.zeros_like(output_buffer)for i in range(num_blocks):with torch.cuda.stream(streams[i]):handle, gather_start_idx = handles[i]handle.wait()  # 等待通信完成# 直接在通信結果上進行矩陣乘法gathered_input = gathered_buffer[gather_start_idx:gather_start_idx + input_chunk_size * world_size]for j in range(num_blocks):weight_chunk = weight[:, j * weight_chunk_size: (j + 1) * weight_chunk_size]output_chunk = outputs[i * input_chunk_size * world_size: (i + 1) * input_chunk_size * world_size, j * weight_chunk_size: (j + 1) * weight_chunk_size]             # 進行局部矩陣相乘output_chunk.add_(torch.matmul(gathered_input, weight_chunk))torch.cuda.synchronize(device)return outputs# 初始化
input = torch.rand((8192, 8192),dtype=torch.float).to(device) 
weight = torch.rand((8192, 8192),dtype=torch.float).to(device) 
all_gather_buffer = torch.zeros((8192 * world_size, 8192),dtype=torch.float).to(device)for i in range(10):output = torch.zeros(input.size(0) * world_size, weight.size(1),dtype=torch.float,device=device)dist.barrier()t0=time.time()with nvtx.annotate(f"iter:{i}", color="blue"):output = all_gather_matmul(rank, world_size, input, weight,all_gather_buffer,output,num_blocks,device)torch.cuda.synchronize()t1=time.time()if rank == 0:print(f"iter:{i} e2e:{t1-t0:.5f} data:{output.mean()}")
EOFexport NCCL_DEBUG=error
export NCCL_IB_DISABLE=1
torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_tilingnsys profile --stats=true -o all_gather_mm_tiling.nsys-rep -f true -t cuda,nvtx --gpu-metrics-device=1,3 \torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_tiling

輸出

iter:0 e2e:0.13553 data:2047.62548828125
iter:1 e2e:0.07687 data:2047.62548828125
iter:2 e2e:0.07717 data:2047.62548828125
iter:3 e2e:0.07645 data:2047.62548828125
iter:4 e2e:0.07724 data:2047.62548828125
iter:5 e2e:0.07586 data:2047.62548828125
iter:6 e2e:0.07587 data:2047.62548828125
iter:7 e2e:0.07589 data:2047.62548828125
iter:8 e2e:0.07626 data:2047.62548828125
iter:9 e2e:0.07549 data:2047.62548828125

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/39480.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/39480.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/39480.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

代理IP的10大誤區:區分事實與虛構

在當今的數字時代&#xff0c;代理已成為在線環境不可或缺的一部分。它們的用途廣泛&#xff0c;從增強在線隱私到繞過地理限制。然而&#xff0c;盡管代理無處不在&#xff0c;但仍存在許多圍繞代理的誤解。在本博客中&#xff0c;我們將探討和消除一些最常見的代理誤解&#…

人腦網絡的多層建模與分析

摘要 了解人類大腦的結構及其與功能的關系&#xff0c;對于各種應用至關重要&#xff0c;包括但不限于預防、處理和治療腦部疾病(如阿爾茨海默病或帕金森病)&#xff0c;以及精神疾病(如精神分裂癥)的新方法。結構和功能神經影像學方面的最新進展&#xff0c;以及計算機科學等…

OBS 免費的錄屏軟件

一、下載 obs 【OBS】OBS Studio 的安裝、參數設置和錄屏、攝像頭使用教程-CSDN博客 二、使用 obs & 輸出無黑屏 【OBS任意指定區域錄屏的方法-嗶哩嗶哩】 https://b23.tv/aM0hj8A OBS任意指定區域錄屏的方法_嗶哩嗶哩_bilibili 步驟&#xff1a; 1&#xff09;獲取區域…

012-GeoGebra基礎篇-構造圓的切線

前邊文章對于基礎內容已經悉數覆蓋了&#xff0c;這一篇我就不放具體的細節&#xff0c;若有需要可以復刻一下 目錄 一、成品展示二、算式內容三、正確性檢查五、文章最后 一、成品展示 二、算式內容 A(0,0) B(3,0) c: Circle(A,B) C(5,4) sSegment(A,C) DMidpoint(s) d: Circ…

k8s部署單節點redis

一、configmap # cat redis-configmap.yaml apiVersion: v1 kind: ConfigMap metadata:name: redis-single-confignamespace: redis data:redis.conf: |daemonize nobind 0.0.0.0port 6379tcp-backlog 511timeout 0tcp-keepalive 300pidfile /data/redis-server.pidlogfile /d…

全網小視頻去水印接口使用說明

一、請求地址&#xff1a; https://www.lytcreate.com/api/qsy/ 二、請求方式&#xff1a;POST 三、請求體&#xff1a;JSON body {"token": "個人中心的token","url": "視頻分享地址"} token獲取地址&#xff0c;訪問&#xff…

uniapp微信小程序使用xr加載模型

1.在根目錄與pages同級創建如下目錄結構和文件&#xff1a; // index.js Component({properties: {modelPath: { // vue頁面傳過來的模型type: String,value: }},data: {},methods: {} }) { // index.json"component": true,"renderer": "xr-frame&q…

Element-plus點擊當前行之后獲取數據顯示跟隨行數據

要實現點擊當前行后&#xff0c;在當前行的下方顯示數據&#xff0c;可以通過以下步驟來實現&#xff1a; 在表格的行點擊事件中獲取當前點擊行的位置信息。根據位置信息動態計算并設置需要顯示數據區域的位置。 下面是一個更新后的示例代碼&#xff0c;演示如何在 Element-P…

Unity 引擎收費模式變革:游戲開發者的挑戰與機遇

Unity 引擎作為游戲開發領域中的重要工具&#xff0c;近日宣布將在 2024 年 1 月 1 日起根據游戲安裝量對開發者進行收費。這一決定引起了業界的廣泛關注和討論。據 Unity 技術博客發布的《Unity 收費模式和配套服務更新》一文&#xff0c;他們選擇這種計費方式是基于每次游戲被…

PHP和phpSpider:如何應對網站變動導致的數據爬取失敗?

php和phpspider&#xff1a;如何應對網站變動導致的數據爬取失敗&#xff1f; 導語&#xff1a; 網絡爬蟲是一種自動化程序&#xff0c;用于從網站上獲取數據并進行處理。PHP是一種廣泛使用的編程語言&#xff0c;而phpSpider是一個基于PHP的開源網絡爬蟲框架。然而&#xff0…

軟降工程學系統實現

一、程序編碼 程序編碼是設計的繼續&#xff0c;將軟件設計的結果翻譯成用某種程序設計語言描述的源代碼。 程序編碼涉及到方法、工具和過程。 程序設計風格和程序設計語言的特性會深刻地影響軟件的質量和可維護性。 要求源程序具有良好的結構性和設計風格。 程序設計風格…

開啟IT世界的探索之旅——致有志于踏入IT領域的高考少年們

高考已成過去&#xff0c;而前方是無限可能的未來。對于那些有志于進入IT領域的高考生來說&#xff0c;這個暑假是你們開啟探索IT世界的絕佳時機。作為一名從事C#軟件開發的專業人員&#xff0c;我希望能通過這篇文章&#xff0c;分享一些學習路線圖和經驗心得&#xff0c;幫助…

【web3】分享一個web入門學習平臺-HackQuest

前言 一直想進入web3行業&#xff0c;但是沒有什么途徑&#xff0c;偶然在電鴨平臺看到HackQuest的共學營&#xff0c;發現真的不錯&#xff0c;并且還接觸到了黑客松這種形式。 鏈接地址&#xff1a;HackQuest 平臺功能 學習路徑&#xff1a;平臺有完整的學習路徑&#xff…

【聊聊原子性,中斷,以及nodejs中的具體示例】

什么是原子性 從一個例子說起&#xff0c; x &#xff0c;讀和寫 &#xff0c; 如圖假設多線程&#xff0c;線程1和線程2同時操作變量x&#xff0c;進行x的操作&#xff0c;那么由于寫的過程中&#xff0c;都會先讀一份x數據到cpu的寄存器中&#xff0c;所以這個時候cpu1 和 c…

MyBatis-plus(下)

目錄 靜態工具 邏輯刪除 枚舉處理器 ?編輯?編輯JSON處理器 分頁插件 案例 靜態工具 只有save與update不需要傳class字節碼 UserController: MyServiceImpl: 改造根據id批量查詢用戶的接口&#xff0c;查詢用戶的同時&#xff0c;查詢出用戶對應的所有地址 Overrid…

容器內存

一、容器內存概述 容器本質上還是一個進程&#xff0c;是一個被隔離和限制的進程。因此容器內存和進程內存在表現形式上其實是一樣的&#xff0c;這塊主要涉及三部分內容&#xff1a;RSS&#xff0c;page cache和swap這三部分&#xff0c;容器基于memory Cgroup對內存進行限制…

用國內鏡像安裝docker 和 docker-compose (ubuntu)

替代方案&#xff0c;改用國內的鏡像站(網易鏡像&#xff09; 1.清除舊版本&#xff08;可選操作&#xff09; for pkg in docker.io docker-doc docker-compose podman-docker containerd runc; do apt-get remove $pkg; done 2.安裝docker apt-get update 首先安裝依賴 apt-g…

Linux驅動開發實戰寶典:設備模型、模塊編程、I2C/SPI/USB外設精講

摘要: 本文將帶你走進 Linux 驅動開發的世界,從設備驅動模型、內核模塊開發基礎開始,逐步深入 I2C、SPI、USB 等常用外設的驅動編寫,結合實際案例,助你掌握 Linux 驅動開發技能。 關鍵詞: Linux 驅動,設備驅動模型,內核模塊,I2C,SPI,USB 一、Linux 設備驅動模型 Li…

mysql創建表的規范

名稱 建表的時候&#xff0c;給表&#xff0c;字段和索引起個好名字 見名知意&#xff1a;好的名字能夠降低溝通和維護的成本名字不宜過長&#xff0c;盡量控制在30個字符以內 大小寫 名字盡量都用小寫字母&#xff0c;因為從視覺上&#xff0c;小寫字母更容易讓人讀懂全部大寫…

Linux嵌入式中MQTT的使用

MQTT是什么&#xff1f; MQTT&#xff08;Message Queuing Telemetry Transport&#xff0c;消息隊列遙測傳輸協議&#xff09;&#xff0c;是一種基于發布/訂閱&#xff08;Publish/Subscribe&#xff09;模式的輕量級通訊協議&#xff0c;該協議構建于TCP/IP協議上&#xff0…