Pytorch FSDP權重分片保存與合并

注:本文章方法只適用Pytorch FSDP1的模型,且切分策略為SHARDED_STATE_DICT場景。

在使用FSDP訓練模型時,為了節省顯存通常會把模型權重也進行切分,在保存權重時為了加速保存通常每個進程各自保存自己持有的部分權重,避免先匯聚到主進程再保存浪費大量時間的問題。保存成分片權重后,如果需要推理則還需要將分片權重進行合并。下面提供了保存分片權重以及將分片權重合并的代碼示例,代碼主要參考accelerate官方源碼。

import osimport torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
import torch.distributed.checkpoint.format_utils as dist_cp_format_utilsdef save_fsdp_model(model: FSDP, fsdp_ckpt_path: str):# refer accelerate/utils/fsdp_utils.py:save_fsdp_modelwith FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):os.makedirs(fsdp_ckpt_path, exist_ok=True)state_dict = {"model": model.state_dict()}dist_cp.save(state_dict=state_dict,storage_writer=dist_cp.FileSystemWriter(fsdp_ckpt_path),planner=DefaultSavePlanner(),)def merge_fsdp_weights(fsdp_ckpt_path: str, save_path: str):# refer accelerate/utils/fsdp_utils.py:merge_fsdp_weightsstate_dict = {}dist_cp_format_utils._load_state_dict(state_dict,storage_reader=dist_cp.FileSystemReader(fsdp_ckpt_path),planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),no_dist=True,)# To handle if state is a dict like {model: {...}}if len(state_dict.keys()) == 1:state_dict = state_dict[list(state_dict)[0]]torch.save(state_dict, save_path)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/95412.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/95412.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/95412.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

IDEA自動生成Mapper、XML和實體文件

1. 引入插件 <build><finalName>demo</finalName><plugins><plugin><groupId>org.mybatis.generator</groupId><artifactId>mybatis-generator-maven-plugin</artifactId><version>1.3.5</version><depe…

單例模式的理解

目錄單例模式1.餓漢式(線程安全)2.懶漢式(通過synchronized修飾獲取實例的方法保證線程安全)3.雙重校驗鎖的方式實現單例模式4.靜態內部類方式實現單例模式【推薦】單例模式 1.餓漢式(線程安全) package 并發的例子.單例模式; // 餓漢式單例模式&#xff08;天然線程安全&…

NLP---IF-IDF案例分析

一案例 - 紅樓夢1首先準備語料庫http://www.dxsxs.com這個網址去下載2 任務一&#xff1a;拆分提取import os import redef split_hongloumeng():# 1. 配置路徑&#xff08;關鍵&#xff1a;根據實際文件位置修改&#xff09; # 腳本所在文件夾&#xff08;自動獲取&#xff0…

LaTeX(排版系統)Texlive(環境)Vscode(編輯器)環境配置與安裝

LaTeX、Texlive 和 Vscode 三者之間的關系&#xff0c;可以把它們理解成語言、工具鏈和編輯器的配合關系。 1.下載Texlive 華為鏡像網站下載 小編這邊下載的是texlive2025.iso最新版的&#xff0c;下載什么版本看自己需求&#xff0c;只要下載后綴未.iso的即可。為避免錯誤&am…

【深入淺出STM32(1)】 GPIO 深度解析:引腳特性、工作模式、速度選型及上下拉電阻詳解

GPIO 深度解析&#xff1a;引腳特性、工作模式、速度選型及上下拉電阻詳解一、GPIO概述二、GPIO的工作模式1、簡述&#xff08;1&#xff09;4種輸入模式&#xff08;2&#xff09;4種輸出模式&#xff08;3&#xff09;4種最大輸出速度2、引腳速度&#xff08;1&#xff09;輸…

第1節 大模型分布式推理基礎與技術體系

前言:為什么分布式推理是大模型時代的核心能力? 當我們談論大模型時,往往首先想到的是訓練階段的千億參數、千卡集群和數月的訓練周期。但對于商業落地而言,推理階段的技術挑戰可能比訓練更復雜。 2025年,某頭部AI公司推出的130B參數模型在單機推理時面臨兩個選擇:要么…

《軟件工程導論》實驗報告一 軟件工程文檔

目 錄 一、實驗目的 二、實驗環境 三、實驗內容與步驟 四、實驗心得 一、實驗目的 1. 理解軟件工程的基本概念&#xff0c;熟悉軟件&#xff0c;軟件生命周期&#xff0c;軟件生存周期過程和軟件生命周期各階段的定義和內容。 2. 了解軟件工程文檔的類別、內容及撰寫軟件工…

基于elk實現分布式日志

1.基本介紹 1.1 什么是分布式日志 在分布式應用中&#xff0c;日志被分散在儲存不同的設備上。如果你管理數十上百臺服務器&#xff0c;你還在使用依次登錄每臺機器的傳統方法查閱日志。這樣是不是感覺很繁瑣和效率低下。所以我們使用集中化的日志管理&#xff0c;分布式日志…

多模態RAG賽題實戰之策略優化--Datawhale AI夏令營

科大訊飛AI大賽&#xff08;多模態RAG方向&#xff09; - Datawhale 項目流程圖 1、升級數據解析方案&#xff1a;從 fitz 到 MinerU PyMuPDF&#xff08;fitz&#xff09;是基于規則的方式提取pdf里面的數據&#xff1b;MinerU是基于深度學習模型通過把PDF內的頁面看成是圖片…

09--解密棧與隊列:數據結構核心原理

1. 棧 1.1. 棧的簡介 棧 是一種 特殊的線性表&#xff0c;具有數據 先進后出 特點。 注意&#xff1a; stack本身 不支持迭代器操作 主要原因是因為stack不支持數據的隨機訪問&#xff0c;必須保證數據先進后出的特點。stack在CPP庫中實現為一種 容器適配器 所謂容器適配器&a…

打造專屬 React 腳手架:從 0 到 1 開發 CLI 工具

前言: 在前端開發中&#xff0c;重復搭建項目環境是個低效的事兒。要是團隊技術棧固定&#xff08;比如 React AntD Zustand TS &#xff09;&#xff0c;每次從零開始配路由、狀態管理、UI 組件&#xff0c;既耗時又容易出錯。這時候&#xff0c;自定義 CLI 腳手架 就派上…

Python day43

浙大疏錦行 Python day43 import torch import numpy as np import pandas as pd import torchvision import torchvision.transforms as transforms import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.utils.data import Da…

python基于Hadoop的超市數據分析系統

前端開發框架:vue.js 數據庫 mysql 版本不限 后端語言框架支持&#xff1a; 1 java(SSM/springboot)-idea/eclipse 2.NodejsVue.js -vscode 3.python(flask/django)–pycharm/vscode 4.php(thinkphp/laravel)-hbuilderx 數據庫工具&#xff1a;Navicat/SQLyog等都可以 摘要&…

如何用 COLMAP 制作 Blender 格式的數據集

如何用 COLMAP 制作 Blender 格式的數據集并劃分出 transforms_train.json、transforms_val.json 和 transforms_test.json。 一、什么是 Blender 格式數據集? Blender 格式數據集是 Nerf 和 Nerfstudio 常用的輸入格式,其核心是包含了相機內外參的 JSON 文件,一般命名為:…

[GESP202309 六級] 2023年9月GESP C++六級上機題題解,附帶講解視頻!

本文為GESP 2023年9月 六級的上機題目詳細題解和講解視頻&#xff0c;覺得有幫助或者寫的不錯可以點個贊。 題目一講解視頻 GESP2023年9月六級上機題一題目二講解視頻 題目一:小羊買飲料 B3873 [GESP202309 六級] 小楊買飲料 - 洛谷 題目大意: 現在超市一共有n種飲料&#…

linux 操作ppt

目錄 方法1&#xff1a;用 libreoffice 打開PPT文件 播放腳本&#xff1a; 方法2&#xff1a;用 python-pptx 創建和編輯PPT 方法3&#xff1a;其他方法 在Linux中&#xff0c;可以使用Python通過python-pptx庫來創建和編輯PPT文件&#xff0c;但直接播放PPT文件需要借助其…

元數據管理與數據治理平臺:Apache Atlas 基本搜索 Basic Search

文中內容僅限技術學習與代碼實踐參考&#xff0c;市場存在不確定性&#xff0c;技術分析需謹慎驗證&#xff0c;不構成任何投資建議。 Apache Atlas 框架是一套可擴展的核心基礎治理服務&#xff0c;使企業能夠有效、高效地滿足 Hadoop 中的合規性要求&#xff0c;并支持與整個…

LangChain4J-(1)-Hello World

一、LangChain4J是什么&#xff1f; LangChain4J 是一個專為 Java 生態系統設計的開源框架&#xff0c;用于簡化與大語言模型&#xff08;LLM&#xff0c;如 OpenAI 的 GPT 系列、Google 的 Gemini、Anthropic 的 Claude 等&#xff09;的集成和交互。它借鑒了 Python 生態中 L…

HTTPS應用層協議-中間攻擊人

HTTPS應用層協議-中間攻擊人 ? Man-in-the-MiddleAttack&#xff0c;簡稱“MITM 攻擊” 確實&#xff0c;在方案 2/3/4 中&#xff0c;客戶端獲取到公鑰 S 之后&#xff0c;對客戶端形成的對稱秘鑰 X 用服務端給客戶端的公鑰 S 進行加密&#xff0c;中間人即使竊取到了數據&am…

利用 Makefile 高效啟動 VIVADO 軟件:深入解析與實踐

利用 Makefile 高效啟動 VIVADO 軟件&#xff1a;深入解析與實踐 系列文章目錄 1、VMware Workstation Pro安裝指南&#xff1a;詳細步驟與配置選項說明 2、VMware 下 Ubuntu 操作系統下載與安裝指南 3.基于 Ubuntu 的 Linux 系統中 Vivado 2020.1 下載安裝教程 文章目錄利用 …