MindSpore基礎教程:LeNet-5 神經網絡在MindSpore中的實現與訓練
官方文檔教程使用已經棄用的MindVision模塊,本文是對官方文檔的更新
深度學習在圖像識別領域取得了顯著的成功,LeNet-5 作為卷積神經網絡的經典之作,在諸多研究和應用中占有重要地位。本文將詳細介紹如何使用 MindSpore 框架實現并訓練一個 LeNet-5 神經網絡,專注于處理MNIST手寫數字數據集。
前言
MindSpore 是華為推出的一種新型深度學習框架,旨在為用戶提供高效、易用的編程體驗。接下來,我們將通過實例來展示如何在 MindSpore 中構建、訓練和評估一個經典的 LeNet-5 神經網絡。
環境配置
MindSpore官網
LeNet-5 網絡結構簡介
LeNet-5 是一個簡單的卷積神經網絡,包含兩個卷積層和三個全連接層。它經常被用于圖像識別任務,特別是在處理像 MNIST 這樣的手寫數字數據集時表現出色。
數據集準備與預處理
首先,我們需要準備并預處理數據集。在這個例子中,我們將使用 MNIST 數據集。以下函數 create_dataset
負責加載數據集,并進行必要的預處理:
def create_dataset(data_path, batch_size=32, repeat_size=1):"""創建用于訓練的MNIST數據集。此函數負責加載MNIST數據集,對數據進行預處理和轉換,以便它們可以用于訓練神經網絡。數據預處理包括調整圖像大小、重新縮放和類型轉換。參數:data_path (str): MNIST數據集的路徑。這應該是包含MNIST數據文件的目錄路徑。batch_size (int, 可選): 每個數據批次的大小。默認值為32。repeat_size (int, 可選): 數據集重復的次數。這用于增加數據集的大小。默認值為1。步驟:1. 加載MNIST數據集。2. 對圖像執行大小調整操作,將圖像大小統一調整為32x32像素。3. 對圖像進行重新縮放和標準化處理。先將像素值縮放到0-1之間,然后進行標準化。4. 將圖像的格式從高寬通道(HWC)轉換為通道高寬(CHW)。5. 對標簽進行類型轉換,將其轉換為整型(int32)。6. 對數據集進行洗牌、批處理和重復操作,以準備訓練過程。返回:返回一個處理過的MNIST數據集,可以直接用于模型訓練。注意:- 數據集的預處理步驟對于訓練深度學習模型來說是非常重要的,它們會影響訓練的效果和速度。- 調整batch_size和repeat_size可以影響模型訓練時的內存消耗和速度。"""mnist_dataset = ds.MnistDataset(data_path)resize_operation = vision.Resize((32, 32), interpolation=Inter.LINEAR)rescale_normalization_op = vision.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)rescale_op = vision.Rescale(1.0 / 255.0, 0.0)hwc_to_chw_op = vision.HWC2CHW()type_cast_op = transforms.TypeCast(mstype.int32)mnist_dataset = mnist_dataset.map(input_columns="label", operations=type_cast_op)mnist_dataset = mnist_dataset.map(input_columns="image",operations=[resize_operation, rescale_op, rescale_normalization_op,hwc_to_chw_op])mnist_dataset = mnist_dataset.shuffle(buffer_size=10000)mnist_dataset = mnist_dataset.batch(batch_size, drop_remainder=True)mnist_dataset = mnist_dataset.repeat(repeat_size)return mnist_dataset
這個函數將數據集中的圖像調整為統一的大小,并進行重新縮放和標準化。
構建 LeNet-5 模型
LeNet-5 模型的構建在 LeNet5
類中實現。此類定義了網絡的各層及其排列:
class LeNet5(nn.Cell):"""LeNet-5 神經網絡結構。這是一個經典的卷積神經網絡,通常用于圖像識別任務。它包含了兩個卷積層和三個全連接層。參數:num_class (int): 輸出層的類別數量。默認為10,適用于MNIST數據集。num_channel (int): 輸入圖像的通道數。對于灰度圖像,此值為1。組件:- conv1: 第一個卷積層,使用有效填充。- conv2: 第二個卷積層,同樣使用有效填充。- fc1: 第一個全連接層。- fc2: 第二個全連接層。- fc3: 第三個全連接層,輸出層。- relu: 激活函數,使用ReLU。- max_pool2d: 最大池化層。- flatten: 扁平化層,用于全連接層之前的數據轉換。方法:- construct(x): 定義了前向傳播的過程。"""def __init__(self, num_class=10, num_channel=1):super(LeNet5, self).__init__()self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))self.relu = nn.ReLU()self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)self.flatten = nn.Flatten()def construct(self, x):x = self.conv1(x)x = self.relu(x)x = self.max_pool2d(x)x = self.conv2(x)x = self.relu(x)x = self.max_pool2d(x)x = self.flatten(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)return x
訓練模型
接下來,我們定義 train_network
函數來訓練模型。此函數接受模型實例、數據集路徑和其他訓練參數:
def train_network(model, epoch_size, data_path, repeat_size, checkpoint_callback):"""訓練神經網絡模型。此函數負責初始化數據集,然后使用指定的模型進行訓練。在訓練過程中,它將記錄損失并保存模型的檢查點。參數:model (Model): 要訓練的神經網絡模型。epoch_size (int): 訓練過程中遍歷數據集的次數。data_path (str): 訓練數據集的路徑。repeat_size (int): 數據集的重復次數,用于擴充數據集。checkpoint_callback (Callback): 用于保存模型檢查點的回調函數。過程:- 使用 `create_dataset` 函數創建訓練數據集。- 調用模型的 `train` 方法進行訓練。- 在訓練過程中,會通過回調函數記錄損失和保存檢查點。注意:- 確保提供的 `data_path` 包含適當格式的數據。"""print("============== 開始訓練 ==============")ds_train = create_dataset(data_path, 32, repeat_size)model.train(epoch_size, ds_train, callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor()],dataset_sink_mode=False)print("============== 訓練結束 ==============")
主函數
最后,我們通過 train
函數和 parse_arguments
函數將所有步驟串聯起來。train
函數負責初始化模型、損失函數、優化器和檢查點回調,然后調用 train_network
進行訓練:
def train(args):"""初始化并訓練LeNet-5神經網絡模型。此函數設置了網絡模型、損失函數、優化器,并定義了模型檢查點。然后,使用指定的參數調用 `train_network` 函數來進行模型的訓練。參數:args (Namespace): 一個包含訓練參數的命名空間對象。此對象應該包含以下屬性:- epochs (int): 模型訓練的迭代次數。- data_url (str): 訓練數據集的路徑。- output_path (str): 保存模型檢查點的路徑。過程:1. 創建 LeNet-5 網絡實例。2. 定義損失函數為 Softmax Cross-Entropy。3. 定義優化器為 Momentum 優化器。4. 創建模型實例,并指定網絡、損失函數、優化器和評估指標。5. 設置模型檢查點配置。6. 初始化模型檢查點回調函數。7. 調用 `train_network` 函數進行訓練。注意:- 確保 `args` 對象包含正確和完整的訓練參數。- 調整優化器和損失函數的參數可以對訓練結果產生影響。- 模型檢查點將保存在 `args.output_path` 指定的路徑中。"""net = LeNet5()net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)model = Model(net, net_loss, net_opt, metrics={"Accuracy": nn.Accuracy()})config_checkpoint = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)checkpoint_callback = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.output_path,config=config_checkpoint)train_network(model, args.epochs, args.data_url, 1, checkpoint_callback)
推理
# 加載網絡
param_dict = load_checkpoint("/root/MyCode/pycharm/lenet5/ckpt/checkpoint_lenet-19_1884.ckpt")
network = LeNet5(num_class=NUM_CLASS, num_channel=1) # 用您定義的LeNet5類創建模型實例
load_param_into_net(network, param_dict) # 將參數加載到網絡中
model = Model(network)def predict_digit(img):# 圖像預處理img = cv2.resize(img, (32, 32)) # 調整圖像大小為32x32img = np.array(img, dtype=np.float32) # 轉換圖像數據類型img = (img - 0.1307) / 0.3081 # 對圖像進行標準化處理img = img[np.newaxis, np.newaxis, :, :] # 改變圖像形狀以符合網絡輸入要求(1, 1, 32, 32)# 將圖像數據轉換為MindSpore張量img_tensor = Tensor(img)# 使用模型進行預測output = model.predict(img_tensor)# 將輸出轉換為概率分布probabilities = Softmax()(output)# 獲取每個類別的概率probabilities_np = probabilities.asnumpy()[0]# 將概率轉換為字典格式labels = [str(i) for i in range(10)] # 類別標簽,例如"0", "1", "2", ..., "9"probabilities_dict = {label: prob for label, prob in zip(labels, probabilities_np)}return probabilities_dictgr.Interface(fn=predict_digit,inputs=gr.Image(image_mode='L'),outputs=gr.Label(num_top_classes=NUM_CLASS),live=False,css=".footer {display:none !important}",title="0-9數字畫板",description="畫0-9數字",thumbnail="https://raw.githubusercontent.com/gradio-app/real-time-mnist/master/thumbnail2.png"
).launch()
結論
通過本文的指南,您可以在 MindSpore 框架中實現并訓練一個經典的 LeNet-5 神經網絡。LeNet-5 在圖像識別任務中展現了卓越的性能,而 MindSpore 的高效和易用性使得深度學習研究和開發更加便捷。您可以根據本文的指導進行實驗,并根據需要調整網絡結構和訓練參數。