數據準備
下載狗與狼分類數據集,數據來自ImageNet,每個分類有大約120張訓練圖像與30張驗證圖像。使用download接口下載數據集,并自動解壓到當前目錄。
全是小狗的圖片
另一邊全是狼的圖片
加載數據集
狼狗數據集提取自ImageNet分類數據集,使用mindspore.dataset.ImageFolderDataset
接口來加載數據集,并進行相關圖像增強操作。
數據集可視化
訓練數據集通過MindSpore的ImageFolderDataset接口加載,返回值為字典。用戶可以通過create_dict_iterator接口創建數據迭代器,使用next迭代訪問數據集。在本章中,每次使用next可獲取18個圖像及標簽數據。
訓練模型
構建Resnet50網絡
固定特征進行訓練
訓練和評估
可視化模型預測
總結
使用遷移學習方法對ImageNet數據集中的狼和狗圖像進行分類的案例。首先介紹了數據集的下載和預處理操作,然后使用ResNet50模型進行訓練和驗證,最后保存了精度最高的模型參數。同時也展示了預測結果的可視化以及固定特征進行訓練的方法。