模型保存與加載
學習心得
- 保存 CheckPoint 格式文件,在模型訓練過程中,可以添加檢查點(CheckPoint)用于保存模型的參數,以便進行推理及再訓練使用。如果想繼續在不同硬件平臺上做推理,可通過網絡和CheckPoint格式文件生成對應的MINDIR、AIR和ONNX格式文件。
可以通過CheckpointConfig對象可以設置CheckPoint的保存策略。model = network() mindspore.save_checkpoint(model, "model.ckpt")
- save_checkpoint_steps表示每隔多少個step保存一次。
- keep_checkpoint_max表示最多保留CheckPoint文件的數量。
- prefix表示生成CheckPoint文件的前綴名。
- directory表示存放文件的目錄。
要加載模型權重,需要先創建相同模型的實例,然后使用from mindspore.train.callback import ModelCheckpoint, CheckpointConfig config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10) ckpoint_cb = ModelCheckpoint(prefix='resnet50', directory=None, config=config_ck) model.train(epoch_num, dataset, callbacks=ckpoint_cb)
load_checkpoint
和load_param_into_net
方法加載參數。model = network()param_dict = mindspore.load_checkpoint("model.ckpt")param_not_load, _ = mindspore.load_param_into_net(model, param_dict)print(param_not_load)
param_not_load
是未被加載的參數列表,為空時代表所有參數均加載成功。[]
- 保存和加載MindIR,當有了CheckPoint文件后,如果想繼續在MindSpore Lite端側做推理,需要通過網絡和CheckPoint生成對應的MINDIR格式模型文件。
- 統一表示:MindIR作為MindSpore云側(訓練)和端側(推理)的統一模型文件,同時存儲了網絡結構和權重參數值。這使得MindSpore能夠在不同的硬件平臺上實現一次訓練多次部署的能力。
- 導出MindIR:MindSpore提供了export接口,可以直接將模型保存為MindIR格式。
- 保存模型
model = network() inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32)) mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
- 加載模型
mindspore.set_context(mode=mindspore.GRAPH_MODE) graph = mindspore.load("model.mindir") model = nn.GraphCell(graph) outputs = model(inputs) print(outputs.shape)