Python語言 AI框架:Mindspore
1.模型構建
class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 10, weight_init="normal", bias_init="zeros"))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logits
model = Network()
print(model)
2.模型保存
mindspore.save_checkpoint(model, "model.ckpt")
3.模型導出-mindir格式
除Checkpoint外,MindSpore提供了云側(訓練)和端側(推理)統一的中間表示(Intermediate Representation,IR)。可使用export接口直接將模型保存為MindIR。
model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
4.加載保存模型
要加載模型權重,需要先創建相同模型的實例,然后使用load_checkpoint和load_param_into_net方法加載參數。
model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load) # 正確輸出[ ]
5.加載導出模型
nn.GraphCell
僅支持圖模式。
mindspore.set_context(mode=mindspore.GRAPH_MODE)graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)
MindIR
同時保存了Checkpoint
和模型結構,因此需要定義輸入Tensor
來獲取輸入shape。
運行結果-模型保存情況如下: