昇思25天學習打卡營第13天|應用實踐之ResNet50遷移學習

基本介紹

? ? ? ? 今日的應用實踐的模型是計算機實踐領域中十分出名的模型----ResNet模型。ResNet是一種殘差網絡結構,它通過引入“殘差學習”的概念來解決隨著網絡深度增加時訓練困難的問題,從而能夠訓練更深的網絡結構。現很多網絡極深的模型或多或少都受此影響。今日的主要任務是使用ResNet50進行遷移學習,所謂的遷移學習是不從頭到尾訓練一個模型,而是在一個特別大的數據集上面訓練得到一個預訓練模型,然后使用該模型的權重作為初始化參數,最后應用于特定的任務。對特定的任務來說也可以對模型進行訓練,但是需要凍結模型的某些參數,一般只訓練模型的分類器部分,不會訓練特征提取部分。下面我們詳細講講今日的應用實踐。

數據集準備

? ? ? ? 本次使用的數據集是來自ImageNet數據集中抽取出來的狼狗數據集,每個類別大概有120張訓練圖像與30張驗證圖像,數據集可通過華為云OBS和相關API直接下載,然后進行加載。可借助MindSpore提供的API對數據集進行加載,同時,因為數據集在送入模型之前需做些處理,這些可封裝到一個函數內,具體代碼如下:

def create_dataset_canidae(dataset_path, usage):"""數據加載"""data_set = ds.ImageFolderDataset(dataset_path,num_parallel_workers=workers,shuffle=True,)# 數據增強操作mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]std = [0.229 * 255, 0.224 * 255, 0.225 * 255]scale = 32if usage == "train":# Define map operations for training datasettrans = [vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),vision.RandomHorizontalFlip(prob=0.5),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]else:# Define map operations for inference datasettrans = [vision.Decode(),vision.Resize(image_size + scale),vision.CenterCrop(image_size),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]# 數據映射操作data_set = data_set.map(operations=trans,input_columns='image',num_parallel_workers=workers)# 批量操作data_set = data_set.batch(batch_size)return data_set

通過上述操作后,我們可以很方便的調用數據集,無論是進行訓練還是可視化,都可以。

模型搭建

? ? ? ? 準備好數據集后,自然就是進行模型搭建,ResNet50是一個非常常見的模型,具體模型結構和不同AI框架下的代碼都很容易獲取,MindSpore官方也有相關的實現代碼,我們直接使用,模型的代碼如下:

class ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一個卷積層,輸入channel為3(彩色圖像),輸出channel為64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化層,縮小圖片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各個殘差網絡結構塊定義,self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化層self.avg_pool = nn.AvgPool2d()# flattern層self.flatten = nn.Flatten()# 全連接層self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return xdef _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加載預訓練模型download(url=model_url, path=pretrianed_ckpt, replace=True)param_dict = load_checkpoint(pretrianed_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"ResNet50模型"resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)

訓練模型

????????我們采用定特征進行訓練,需要凍結除最后一層之外的所有網絡層,最后一層其實就是一個分類層,前面是特征提取層,MindSpore可以通過設置?requires_grad == False?凍結參數,以便不在反向傳播中計算梯度,具體代碼如下:

import mindspore as ms
import matplotlib.pyplot as plt
import os
import timenet_work = resnet50(pretrained=True)# 全連接層輸入層的大小
in_channels = net_work.fc.in_channels
# 輸出通道數大小為狼狗分類數2
head = nn.Dense(in_channels, 2)
# 重置全連接層
net_work.fc = head# 平均池化層kernel size為7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化層
net_work.avg_pool = avg_pool# 凍結除最后一層外的所有參數
for param in net_work.get_parameters():if param.name not in ["fc.weight", "fc.bias"]:param.requires_grad = False# 定義優化器和損失函數
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')def forward_fn(inputs, targets):logits = net_work(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss# 實例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": train.Accuracy()})

一切準備妥當后,便可以進行訓練,由于有預訓練模型,模型的訓練速度非常快,比從頭到尾訓練快了好幾倍。訓練了5輪,效果就非常好了:

模型可視化預測

? ? ? ? 有了訓練好的模型,自然要看看訓練得好不好,除了看評價指標,最直觀的就是實際使用一下模型,由于這是圖像分類任務,所以將其可視化。這一整個流程的代碼如下:

def visualize_model(best_ckpt_path, val_ds):net = resnet50()# 全連接層輸入層的大小in_channels = net.fc.in_channels# 輸出通道數大小為狼狗分類數2head = nn.Dense(in_channels, 2)# 重置全連接層net.fc = head# 平均池化層kernel size為7avg_pool = nn.AvgPool2d(kernel_size=7)# 重置平均池化層net.avg_pool = avg_pool# 加載模型參數param_dict = ms.load_checkpoint(best_ckpt_path)ms.load_param_into_net(net, param_dict)model = train.Model(net)# 加載驗證集的數據進行驗證data = next(val_ds.create_dict_iterator())images = data["image"].asnumpy()labels = data["label"].asnumpy()class_name = {0: "dogs", 1: "wolves"}# 預測圖像類別output = model.predict(ms.Tensor(data['image']))pred = np.argmax(output.asnumpy(), axis=1)# 顯示圖像及圖像的預測值plt.figure(figsize=(5, 5))for i in range(4):plt.subplot(2, 2, i + 1)# 若預測正確,顯示為藍色;若預測錯誤,顯示為紅色color = 'blue' if pred[i] == labels[i] else 'red'plt.title('predict:{}'.format(class_name[pred[i]]), color=color)picture_show = np.transpose(images[i], (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])picture_show = std * picture_show + meanpicture_show = np.clip(picture_show, 0, 1)plt.imshow(picture_show)plt.axis('off')plt.show()

調用該函數后,可視化結果如下:可以看出還是很準的

Jupyter運行情況

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/43404.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/43404.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/43404.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

數據鏈路層(超詳細)

引言 數據鏈路層是計算機網絡協議棧中的第二層,位于物理層之上,負責在相鄰節點之間的可靠數據傳輸。數據鏈路層使用的信道主要有兩種類型:點對點信道和廣播信道。點對點信道是指一對一的通信方式,而廣播信道則是一對多的通信方式…

算法工程師第五天(● 哈希表理論基礎 ● 242.有效的字母異位詞 ● 349. 兩個數組的交集 ● 202. 快樂數● 1. 兩數之和 )

參考文獻 代碼隨想錄 一、有效的字母異位詞 給定兩個字符串 s 和 t ,編寫一個函數來判斷 t 是否是 s 的字母異位詞。 注意:若 s 和 t 中每個字符出現的次數都相同,則稱 s 和 t 互為字母異位詞。 示例 1: 輸入: s "anagram", …

風險評估:Tomcat的安全配置,Tomcat安全基線檢查加固

「作者簡介」:冬奧會網絡安全中國代表隊,CSDN Top100,就職奇安信多年,以實戰工作為基礎著作 《網絡安全自學教程》,適合基礎薄弱的同學系統化的學習網絡安全,用最短的時間掌握最核心的技術。 這一章節我們需…

grafana數據展示

目錄 一、安裝步驟 二、如何添加喜歡的界面 三、自動添加注冊客戶端主機 一、安裝步驟 啟動成功后 可以查看端口3000是否啟動 如果啟動了就在瀏覽器輸入IP地址:3000 賬號密碼默認是admin 然后點擊 log in 第一次會讓你修改密碼 根據自定義密碼然后就能登錄到界面…

高職物聯網實訓室

一、高職物聯網實訓室建設背景 隨著《中華人民共和國國民經濟和社會發展第十四個五年規劃和2035年遠景目標綱要》的發布,中國正式步入加速數字化轉型的新時代。在數字化浪潮中,物聯網技術作為連接物理世界與數字世界的橋梁,其重要性日益凸顯…

Golang | Leetcode Golang題解之第224題基本計算器

題目&#xff1a; 題解&#xff1a; func calculate(s string) (ans int) {ops : []int{1}sign : 1n : len(s)for i : 0; i < n; {switch s[i] {case :icase :sign ops[len(ops)-1]icase -:sign -ops[len(ops)-1]icase (:ops append(ops, sign)icase ):ops ops[:len(o…

Knife4j的原理及應用詳解(三)

本系列文章簡介&#xff1a; 在當今快速發展的軟件開發領域&#xff0c;API&#xff08;Application Programming Interface&#xff0c;應用程序編程接口&#xff09;作為不同軟件應用之間通信的橋梁&#xff0c;其重要性日益凸顯。隨著微服務架構的興起&#xff0c;API的數量…

價值投資者什么時候賣出股票?

經常有人說&#xff0c;會買的只是徒弟&#xff0c;會賣的才是師傅。 在閱讀《戰勝華爾街》的過程中&#xff0c;也多次感受到林奇先生的賣出邏輯&#xff0c;當股票的價格充分體現了公司的價值的時候&#xff0c;就是該賣出股票的時候。但這只是理論上的&#xff0c;從林奇先…

數據中臺指標管理系統

您所描述的是一個數據中臺指標管理系統&#xff0c;它基于Spring Cloud技術棧構建。數據中臺是企業數據管理和應用的中心平臺&#xff0c;它整合了企業內外部的數據資源&#xff0c;提供數據服務和數據管理能力。以下是您提到的各個模塊的簡要概述&#xff1a; 1. **首頁**&am…

JSP WEB開發(四) MVC模式

MVC模式介紹 MVC&#xff08;Model-View-Controller&#xff09;是一種軟件設計模式&#xff0c;最早出現在Smalltalk語言中&#xff0c;后來在Java中得到廣泛應用&#xff0c;并被Sun公司推薦為Java EE平臺的設計模式。它把應用程序分成了三個核心模塊&#xff1a;模型層、視…

2024年有多少程序員轉行了?

疫情后大環境下行&#xff0c;各行各業的就業情況都是一言難盡。互聯網行業更是極不穩定&#xff0c;頻頻爆出裁員的消息。大家都說2024年程序員的就業很難&#xff0c;都很焦慮。 在許多人眼里&#xff0c;程序員可能是一群背著電腦、進入高大上寫字樓的職業&#xff0c;他們…

SVN 80道面試題及參考答案(2萬字長文)

目錄 解釋SVN的全稱和主要功能。 SVN與CVS相比,有哪些主要改進? 描述SVN的工作流程。 什么是版本庫(repository)?它存儲了什么? 解釋工作副本(working copy)的概念。 SVN如何處理文件的版本控制? SVN中的“commit”是什么意思? 解釋“update”操作的作用。 如何…

Datawhale AI 夏令營 機器學習挑戰賽

一、賽事背景 在當今科技日新月異的時代&#xff0c;人工智能&#xff08;AI&#xff09;技術正以前所未有的深度和廣度滲透到科研領域&#xff0c;特別是在化學及藥物研發中展現出了巨大潛力。精準預測分子性質有助于高效篩選出具有優異性能的候選藥物。以PROTACs為例&#x…

Hi3861 OpenHarmony嵌入式應用入門--MQTT

MQTT 是機器對機器(M2M)/物聯網(IoT)連接協議。它被設計為一個極其輕量級的發布/訂閱消息傳輸 協議。對于需要較小代碼占用空間和/或網絡帶寬非常寶貴的遠程連接非常有用&#xff0c;是專為受限設備和低帶寬、 高延遲或不可靠的網絡而設計。這些原則也使該協議成為新興的“機器…

AutoMQ 生態集成 Kafdrop-ui

Kafdrop [1] 是一個為 Kafka 設計的簡潔、直觀且功能強大的Web UI 工具。它允許開發者和管理員輕松地查看和管理 Kafka 集群的關鍵元數據&#xff0c;包括主題、分區、消費者組以及他們的偏移量等。通過提供一個用戶友好的界面&#xff0c;Kafdrop 大大簡化了 Kafka 集群的監控…

量產工具一一UI系統(四)

目錄 前言 一、按鈕數據結構抽象 1.ui.h 二、按鍵處理 1.button.c 2.disp_manager.c 3.disp_manager.h 三、單元測試 1.ui_test.c 2.上機測試 前言 前面我們實現了顯示系統框架&#xff0c;輸入系統框架和文字系統框架&#xff0c;鏈接&#xff1a; 量產工具一一顯…

Redis 底層數據結構

? 簡單動態字符串 ? 鏈表 ? 字典 ? 跳躍表 ? 整數集合 ? 壓縮列表 ? 對象 SDS 增加了len和free屬性&#xff0c;記錄buf數組的使用空間和剩余空間。好處:strken函數直接讀取len值&#xff0c;時間復雜度是O(1)&#xff1b;預分配buf長度&#xf…

集控中心操作臺材質選擇如何選擇

作為集控中心的核心組成部分&#xff0c;操作臺不僅承載著各種設備和工具&#xff0c;更是工作人員進行監控、操作和管理的重要平臺。因此&#xff0c;選擇適合的集控中心操作臺材質顯得尤為重要。 一、材質選擇的考量因素 在選擇集控中心操作臺材質時&#xff0c;我們需要綜合…

SpringCloud跨微服務的遠程調用,如何發起網絡請求,RestTemplate

在我們的業務流程之中不一定都會是自己模塊查詢自己模塊的信息&#xff0c;有些時候就需要去結合其他模塊的信息來進行一些查詢完成相應的業務流程&#xff0c;但是在SpringCloud每個模塊都相對獨立&#xff0c;數據庫也有數據隔離。所以當我們需要其他微服務模塊的信息的時候&…

什么是SpringCloud Stream?

Spring Cloud Stream 是一個構建消息驅動微服務的框架&#xff0c;其基于Spring Boot來開發&#xff0c;并使用Spring Integration來連接消息代理中間件。該項目的目標是提供一套用于開發消息驅動應用的通用模型&#xff0c;并定義了用于發送和接收消息的綁定器&#xff08;Bin…