《昇思 25 天學習打卡營第 10 天 | ResNet50 遷移學習 》
活動地址:https://xihe.mindspore.cn/events/mindspore-training-camp
簽名:Sam9029
使用遷移學習進行狼狗圖像分類
簡介
在機器學習和深度學習中,我們經常面臨數據不足的問題。
遷移學習是一種解決這一問題的有效方法。
本章節將通過一個簡單的案例,介紹如何使用遷移學習對狼和狗的圖像進行分類。
遷移學習概念
遷移學習是一種學習方式,它允許我們將在一個大型數據集(如 ImageNet)上預訓練的模型應用于一個新的、通常較小的數據集。這樣,我們可以利用預訓練模型已經學到的特征,而不必從頭開始訓練整個網絡。
使用模型 ResNet50
- ResNet50 是一種深度卷積神經網絡(CNN)架構, 由微軟研究院的 Kaiming He 等人在 2015 年提出,并在多個視覺識別任務中取得了突破性的性能。
ResNet50 的應用:
圖像分類:ResNet50 可以用于將圖像分類到 1000 個類別中,這是 ImageNet 數據集的標準任務。
物體檢測:通過將 ResNet50 與區域建議網絡(Region Proposal Networks, RPN)結合,可以用于物體檢測任務。
語義分割:ResNet50 也可以用于像素級的圖像理解,即語義分割,其中每個像素都被分類到相應的類別。
數據準備
首先,我們需要下載并準備數據集。在這個案例中,我們使用的是來自 ImageNet 的狼和狗的圖像數據集。
from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
download(dataset_url, "./datasets-Canidae", kind="zip", replace=True)
加載和增強數據集
使用 MindSpore 的數據集加載接口ImageFolderDataset
來加載數據,并進行一些圖像增強操作,如隨機裁剪、水平翻轉等。
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as visiondef create_dataset_canidae(dataset_path, usage):"""數據加載"""data_set = ds.ImageFolderDataset(dataset_path,num_parallel_workers=workers,shuffle=True,)# 數據增強操作mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]std = [0.229 * 255, 0.224 * 255, 0.225 * 255]scale = 32if usage == "train":# Define map operations for training datasettrans = [vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),vision.RandomHorizontalFlip(prob=0.5),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]else:# Define map operations for inference datasettrans = [vision.Decode(),vision.Resize(image_size + scale),vision.CenterCrop(image_size),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]# other code ...
訓練模型&&訓練過程
我們選擇 ResNet50 作為基礎模型,并對其進行調整以適應我們的分類任務。
def resnet50(num_classes: int = 1000, pretrained: bool = False):"ResNet50模型"resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)
- 固定特征進行訓練
- 訓練和評估
- 可視化模型預測
學嘛了,完全是懵懵懂懂的,跑了一篇流程,收獲就是知道了計算機視覺 識別圖片的 過程,
使用 全卷積化 網絡的 深度學習網絡模型,來對圖片進行分類和識別