Note: the method in this article only applies to PyTorch FSDP1 models, and only when checkpoints are saved with the SHARDED_STATE_DICT state-dict type.
When training a model with FSDP, the model weights are usually sharded across processes to save GPU memory. To speed up checkpointing, each process saves only the shard of the weights it holds, which avoids the large overhead of first gathering the full weights onto the main process before saving. Once the weights have been saved as shards, they must be merged back into a single file before they can be used for inference. The code below shows how to save sharded weights and how to merge them afterwards; it is mainly adapted from the official accelerate source code.
import os

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils


def save_fsdp_model(model: FSDP, fsdp_ckpt_path: str):
    # Adapted from accelerate/utils/fsdp_utils.py: save_fsdp_model.
    # Each rank writes only the shards it owns, so no gather to rank 0 is needed.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        os.makedirs(fsdp_ckpt_path, exist_ok=True)
        state_dict = {"model": model.state_dict()}
        dist_cp.save(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(fsdp_ckpt_path),
            planner=DefaultSavePlanner(),
        )


def merge_fsdp_weights(fsdp_ckpt_path: str, save_path: str):
    # Adapted from accelerate/utils/fsdp_utils.py: merge_fsdp_weights.
    # Loads all shards into one full state dict on a single process (no_dist=True).
    state_dict = {}
    dist_cp_format_utils._load_state_dict(
        state_dict,
        storage_reader=dist_cp.FileSystemReader(fsdp_ckpt_path),
        planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    # Unwrap the outer key if the state dict looks like {"model": {...}}.
    if len(state_dict.keys()) == 1:
        state_dict = state_dict[list(state_dict)[0]]
    torch.save(state_dict, save_path)
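For orientation, here is a minimal usage sketch of how the two helpers above might be called end to end. The toy nn.Linear model, the checkpoint paths "ckpt/fsdp_shards" and "ckpt/merged.pt", and the torchrun launch command are illustrative assumptions, not part of the accelerate source.

# Minimal usage sketch (assumed paths and toy model); launch the training
# part with something like: torchrun --nproc_per_node=8 train.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def train_and_save():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Toy model standing in for a real network; FSDP shards its parameters.
    model = FSDP(nn.Linear(1024, 1024).cuda())
    # ... training loop omitted ...

    # Every rank writes only the shards it owns into this directory.
    save_fsdp_model(model, "ckpt/fsdp_shards")
    dist.barrier()
    dist.destroy_process_group()


def merge_and_load():
    # Run once on a single process; merge_fsdp_weights loads with
    # no_dist=True, so no process group is required here.
    merge_fsdp_weights("ckpt/fsdp_shards", "ckpt/merged.pt")
    full_model = nn.Linear(1024, 1024)
    full_model.load_state_dict(torch.load("ckpt/merged.pt"))
    return full_model

The merged file is an ordinary torch.save state dict, so it can be loaded into the unwrapped model for inference or converted to other formats as needed.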