DeepSpeed簡介及加速模型訓練

DeepSpeed是由微軟開發的開源深度學習優化框架,專注于大規模模型的高效訓練與推理。其核心目標是通過系統級優化技術降低顯存占用、提升計算效率,并支持千億級參數的模型訓練。

官網鏈接:deepspeed
訓練代碼下載:git代碼

一、DeepSpeed的核心作用

  1. 顯存優化與高效內存管理

    • ZeRO(Zero Redundancy Optimizer)技術:通過分片存儲模型狀態(參數、梯度、優化器狀態)至不同GPU或CPU,顯著減少單卡顯存占用。例如,ZeRO-2可將顯存占用降低8倍,支持單卡訓練130億參數模型。
      在這里插入圖片描述

    • Offload技術:將優化器狀態卸載到CPU或NVMe硬盤,擴展至TB級內存,支持萬億參數模型訓練。

    • 激活值重計算(Activation Checkpointing):犧牲計算時間換取顯存節省,適用于長序列輸入。

  2. 靈活的并行策略

    • 3D并行:融合數據并行(DP)、模型并行(張量并行TP、流水線并行PP),支持跨節點與節點內并行組合,適應不同硬件架構。

    • 動態批處理與梯度累積:減少通信頻率,支持超大Batch Size訓練。

  3. 訓練加速與混合精度支持

    • 混合精度訓練:支持FP16/BF16,結合動態損失縮放平衡效率與數值穩定性。

    • 稀疏注意力機制:針對長序列任務優化,執行效率提升6倍。

    • 通信優化:支持MPI、NCCL等協議,降低分布式訓練通信開銷。

  4. 推理優化與模型壓縮

    • 低精度推理:通過INT8/FP16量化減少模型體積,提升推理速度。

    • 模型剪枝與蒸餾:壓縮模型參數,降低部署成本。


二、與pytorch 對比分析

1. 優勢

  • 顯存效率:相比PyTorch DDP,單卡80GB GPU可訓練130億參數模型(傳統方法僅支持約10億)。

  • 并行靈活性:支持3D并行組合,優于Horovod(側重數據并行)和Megatron(側重模型并行)。

  • 生態集成:與Hugging Face Transformers、PyTorch無縫兼容,簡化現有項目遷移。

  • 全流程覆蓋:同時優化訓練與推理,而vLLM僅專注推理優化。

2. 局限性

  • 配置復雜度:分布式訓練需手動調整通信策略和分片參數,學習曲線陡峭(需編寫JSON配置文件)。

  • 硬件依賴:部分高級功能(如ZeRO-Infinity)依賴NVMe硬盤或特定GPU架構。

  • 推理效率:純推理場景下,vLLM的吞吐量更高(連續批處理優化更專精)。


三、訓練用例

1、ds_config.json(deepspeed執行訓練時,使用的配置文件)
  • deepspeed訓練模型時,不需要在代碼中定義優化器,只需要在 json 文件中進行配置即可, json文件內容如下:
{"train_batch_size": 128, //所有GPU上的 單個訓練批次大小 之和"gradient_accumulation_steps": 1, //梯度累積 步數"optimizer": {"type": "Adam", //選擇的 優化器"params": {"lr": 0.00015 //相關學習率大小}},"zero_optimization": { //加速策略"stage":2}
}

2、訓練函數

  • 將模型包裝成 deepspeed 形式
#將模型 包裝成 deepspeed 形式
model_engine, _, _, _ = deepspeed.initialize(args=args,model=model,model_parameters=model.parameters())
  • 使用 deepspeed 包裝后的模型 進行 反向傳播和梯度更新
#使用 deepspeed 進行 反向傳播和梯度更新
#反向傳播
model_engine.backward(loss)#梯度更新
model_engine.step()
  • 完整訓練代碼如下:
'''
使用命令行進行啟動啟動命令如下:
deepspeed ds_train.py --epochs 10 --deepspeed --deepspeed_config ds_config.json
'''import argparse
import torch
import torchvision
import deepspeed
from model_definition import load_data, CustomModelif __name__ == '__main__':#讀取命令行 傳遞的參數parser = argparse.ArgumentParser()parser.add_argument("--local_rank", help = "local device id on current node", type = int, default=0)parser.add_argument("--epochs", type = int, default=1)parser = deepspeed.add_config_arguments(parser)args = parser.parse_args()#獲取數據集train_loader, test_loader = load_data() #數據集加載器中的 batch_size的大小 = (ds_config.json中 train_batch_size/gpu數量)#獲取原始模型model = CustomModel().cuda()#將模型 包裝成 deepspeed 形式model_engine, _, _, _ = deepspeed.initialize(args=args,model=model,model_parameters=model.parameters())loss_fn = torch.nn.CrossEntropyLoss().cuda() # 損失函數(分類任務常用)for i in range(args.epochs):for inputs, labels in train_loader:#前向傳播inputs = inputs.cuda()labels = labels.cuda()outputs = model_engine(inputs)loss = loss_fn(outputs, labels)#使用 deepspeed 進行 反向傳播和梯度更新#反向傳播model_engine.backward(loss)#梯度更新model_engine.step()model_engine.save_checkpoint('./ds_models', i)#模型保存torch.save(model_engine.module.state_dict(),'deepspeed_train_model.pth')
3、模型評估

import argparse
import torch
import torchvision
import deepspeed
from model_definition import load_data, CustomModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt# 1. 定義數據轉換(預處理)
transform = transforms.Compose([transforms.ToTensor(),          # 轉為Tensor格式(自動歸一化到0-1)transforms.Normalize((0.1307,), (0.3081,))  # 標準化(MNIST的均值和標準差)
])test_data = datasets.MNIST(root='./data',train=False,          # 測試集transform=transform)#獲取數據集
train_loader, test_loader = load_data()model = CustomModel()
model.load_state_dict(torch.load('deepspeed_train_model.pth'))#評估
model.eval()  # 設置為評估模式
correct = 0
total = 0with torch.no_grad():  # 不計算梯度(節省內存)for images, labels in test_loader:images, labels = images, labelsoutputs = model(images)_, predicted = torch.max(outputs.data, 1)  # 取概率最大的類別total += labels.size(0)correct += (predicted == labels).sum().item()print(f"測試集準確率: {100 * correct / total:.2f}%")# 隨機選擇一張測試圖片
index = np.random.randint(0,1000)  # 可以修改這個數字試不同圖片
test_image, true_label = test_data[index]
test_image = test_image.unsqueeze(0)  # 增加批次維度# 預測
with torch.no_grad():output = model(test_image)
predicted_label = torch.argmax(output).item()print(f"預測: {predicted_label}, 真實: {true_label}")# 顯示結果
plt.imshow(test_image.cpu().squeeze(), cmap='gray')
plt.title(f"預測: {predicted_label}, 真實: {true_label}")
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/83903.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/83903.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/83903.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

集星獺 | 重塑集成體驗:新版編排重構仿真電商訂單數據入庫

概要介紹 新版服務編排以可視化模式驅動電商訂單入庫流程升級,實現訂單、客戶、庫存、發票、發貨等環節的自動化處理。流程中通過循環節點、判斷邏輯與數據查詢的編排,完成了低代碼構建業務邏輯,極大提升訂單處理效率與業務響應速度。 背景…

AMO——下層RL與上層模仿相結合的自適應運動優化:讓人形行走操作(loco-manipulation)兼顧可行性和動力學約束

前言 自從去年24年Q4,我司「七月在線」側重具身智能的場景落地與定制開發之后 去年Q4,每個月都會進來新的具身需求今年Q1,則每周都會進來新的具身需求Q2的本月起,一周不止一個需求 特別是本周,幾乎每天都有國企、央企…

MATLAB中進行語音信號分析

在MATLAB中進行語音信號分析是一個涉及多個步驟的過程,包括時域和頻域分析、加窗、降噪濾波、端點檢測以及特征提取等。 1. 加載和預覽語音信號 首先,你需要加載一個語音信號文件。MATLAB支持多種音頻文件格式,如.wav。 [y, fs] audiorea…

JWT令牌驗證

一、JWT 驗證方式詳解 JWT(JSON Web Token)的驗證核心是確保令牌未被篡改且符合業務規則,主要分為以下步驟: 1. 令牌解析與基礎校驗 收到客戶端傳遞的 JWT 后,首先按 . 分割為三部分:Header、Payload、S…

一文講清python、anaconda的安裝以及pycharm創建工程

軟件下載 Pycharm下載地址: Other Versions - PyCharm anaconda下載地址: https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Windows-x86_64.exe 安裝步驟 一、 Python 解釋器的安裝步驟 安裝目錄介紹: 二、 Anaconda 安裝 2.1 安裝步…

Mac如何允許安裝任何來源軟件?

打開系統偏好設置-安全性與隱私,點擊右下角的解鎖按鈕,選擇允許從任何來源。 如果沒有這一選項,請到打開終端,輸入命令行:sudo spctl --master-disable, 輸入命令后回車,輸入電腦的開機密碼后回車。 返回“…

React Flow 中 Minimap 與 Controls 組件使用指南:交互式小地圖與視口控制定制(含代碼示例)

本文為《React Agent:從零開始構建 AI 智能體》專欄系列文章。 專欄地址:https://blog.csdn.net/suiyingy/category_12933485.html。項目地址:https://gitee.com/fgai/react-agent(含完整代碼示?例與實戰源)。完整介紹…

Windows Ubuntu 目錄映射關系

情況一:你是通過 WSL (Windows Subsystem for Linux) 安裝 Ubuntu 這是最常見的情況。如果你在 Microsoft Store 安裝了 “Ubuntu”,默認就是 WSL。 📁 目錄映射關系如下: 從 Ubuntu(WSL)訪問 Windows&…

雙指針法高效解決「移除元素」問題

雙指針法高效解決「移除元素」問題 雙指針法高效解決「移除元素」問題一、問題描述二、解法解析:雙指針法1. 核心思想2. 算法步驟3. 執行過程示例 三、關鍵點分析四、復雜度分析五、與其他解法的比較1. 快慢指針法2. 本解法的優勢 六、實際應用場景七、總結 雙指針法…

知識圖譜構架

目錄 知識圖譜構架 一、StanfordNLP 和 spaCy 工具介紹 (一)StanfordNLP 主要功能 使用示例 (二)spaCy 主要功能 使用示例 二、CRF 和 BERT 的基本原理和入門 (一)CRF(條件隨機場&…

激光三角測量標定與應用

文章目錄 1,介紹。2,技術原理3,類型。3.1,直射式3.2,斜射式3.3,兩種三角位移傳感器特性的比較 4,什么是光片?5,主要的算子。1,create_sheet_of_light_model2&…

高可用消息隊列實戰:AWS SQS 在分布式系統中的核心解決方案

引言:消息隊列的“不可替代性” 在微服務架構和分布式系統盛行的今天,消息隊列(Message Queue) 已成為解決系統解耦、流量削峰、異步處理等難題的核心組件。然而,傳統的自建消息隊列(如RabbitMQ、Kafka&am…

人工智能核心知識:AI Agent 的四種關鍵設計模式

人工智能核心知識:AI Agent 的四種關鍵設計模式 一、引言 在人工智能領域,AI Agent(人工智能代理)是實現智能行為和決策的核心實體。它能夠感知環境、做出決策并采取行動以完成特定任務。為了設計高效、靈活且適應性強的 AI Age…

平替BioLegend品牌-Elabscience PE Anti-Mouse Foxp3抗體:流式細胞術中的高效工具,助力免疫細胞分析!”

概述 調節性T細胞(Treg)在維持免疫耐受和抑制過度免疫反應中發揮關鍵作用,其標志性轉錄因子Foxp3(Forkhead box P3)是Treg功能研究的重要靶點。Elabscience 推出的抗小鼠Foxp3抗體(3G3-E)&…

編程日志5.13

鄰接表的基礎代碼 #include<iostream> using namespace std; //鄰接表的類聲明 class Graph {private: //結構體EdgeNode表示圖中的邊結點,包含頂點vertex、權重weight和指向下一個邊結點的指針next struct EdgeNode { int vertex; int weight; …

PowerBI 矩陣實現動態行內容(如前后銷售數據)統計數據,以及過濾同時為0的數據

我們有一張活動表 和 一張銷售表 我們想實現如下的效果&#xff0c;當選擇某個活動時&#xff0c;顯示活動前后3天的銷售對比圖&#xff0c;如下&#xff1a; 實現方法&#xff1a; 1.新建一個表&#xff0c;用于顯示列&#xff1a; 2.新建一個度量值&#xff0c;用SELECTEDVA…

Prompt Tuning:高效微調大模型的新利器

Prompt Tuning(提示調優)是什么 Prompt Tuning(提示調優) 是大模型參數高效微調(Parameter-Efficient Fine-Tuning, PEFT)的重要技術之一,其核心思想是通過優化 連續的提示向量(而非整個模型參數)來適配特定任務。以下是關于 Prompt Tuning 的詳細解析: 一、核心概念…

杰發科技AC7840——如何把結構體數據寫到Dflash中

1. 結構體數據被存放在Pflash中 正常情況下&#xff0c;可以看到全局變量的結構體數據被存放在Pflash中 數字部分存在RAM中 2. 最小編程單位 8字節編程&#xff0c;因此如果結構體存放在Dfalsh中&#xff0c;進行寫操作&#xff0c;需要寫8字節的倍數 第一種辦法&#xff1a;…

CSS 選擇器入門

一、CSS 選擇器基礎&#xff1a;快速掌握核心概念 什么是選擇器&#xff1f; CSS 選擇器就像 “網頁元素的遙控器”&#xff0c;用于定位 HTML 中的特定元素并應用樣式。 /* 結構&#xff1a;選擇器 { 屬性: 值; } */ p { color: red; } /* 選擇所有<p>元素&#xff0c;…

Anaconda3安裝教程(附加安裝包)Anaconda詳細安裝教程Anaconda3 最新版安裝教程

多環境隔離 可同時維護生產環境、開發環境、測試環境&#xff0c;例如&#xff1a; conda create -n ml python3.10 # 創建機器學習環境 conda activate ml # 激活環境三、Anaconda3 安裝教程 解壓Anaconda3安裝包 找到下載的 Anaconda3 安裝包&#xff08;.ex…