【深度學習筆記】計算機視覺——微調

微調

前面的一些章節介紹了如何在只有6萬張圖像的Fashion-MNIST訓練數據集上訓練模型。
我們還描述了學術界當下使用最廣泛的大規模圖像數據集ImageNet,它有超過1000萬的圖像和1000類的物體。
然而,我們平常接觸到的數據集的規模通常在這兩者之間。

假如我們想識別圖片中不同類型的椅子,然后向用戶推薦購買鏈接。
一種可能的方法是首先識別100把普通椅子,為每把椅子拍攝1000張不同角度的圖像,然后在收集的圖像數據集上訓練一個分類模型。
盡管這個椅子數據集可能大于Fashion-MNIST數據集,但實例數量仍然不到ImageNet中的十分之一。
適合ImageNet的復雜模型可能會在這個椅子數據集上過擬合。
此外,由于訓練樣本數量有限,訓練模型的準確性可能無法滿足實際要求。

為了解決上述問題,一個顯而易見的解決方案是收集更多的數據。
但是,收集和標記數據可能需要大量的時間和金錢。
例如,為了收集ImageNet數據集,研究人員花費了數百萬美元的研究資金。
盡管目前的數據收集成本已大幅降低,但這一成本仍不能忽視。

另一種解決方案是應用遷移學習(transfer learning)將從源數據集學到的知識遷移到目標數據集
例如,盡管ImageNet數據集中的大多數圖像與椅子無關,但在此數據集上訓練的模型可能會提取更通用的圖像特征,這有助于識別邊緣、紋理、形狀和對象組合。
這些類似的特征也可能有效地識別椅子。

步驟

本節將介紹遷移學習中的常見技巧:微調(fine-tuning)。如 fig_finetune所示,微調包括以下四個步驟。

  1. 在源數據集(例如ImageNet數據集)上預訓練神經網絡模型,即源模型
  2. 創建一個新的神經網絡模型,即目標模型。這將復制源模型上的所有模型設計及其參數(輸出層除外)。我們假定這些模型參數包含從源數據集中學到的知識,這些知識也將適用于目標數據集。我們還假設源模型的輸出層與源數據集的標簽密切相關;因此不在目標模型中使用該層。
  3. 向目標模型添加輸出層,其輸出數是目標數據集中的類別數。然后隨機初始化該層的模型參數。
  4. 在目標數據集(如椅子數據集)上訓練目標模型。輸出層將從頭開始進行訓練,而所有其他層的參數將根據源模型的參數進行微調。

在這里插入圖片描述

fig_finetune

當目標數據集比源數據集小得多時,微調有助于提高模型的泛化能力。

熱狗識別

讓我們通過具體案例演示微調:熱狗識別。
我們將在一個小型數據集上微調ResNet模型。該模型已在ImageNet數據集上進行了預訓練。
這個小型數據集包含數千張包含熱狗和不包含熱狗的圖像,我們將使用微調模型來識別圖像中是否包含熱狗。

%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

獲取數據集

我們使用的[熱狗數據集來源于網絡]。
該數據集包含1400張熱狗的“正類”圖像,以及包含盡可能多的其他食物的“負類”圖像。
含著兩個類別的1000張圖片用于訓練,其余的則用于測試。

解壓下載的數據集,我們獲得了兩個文件夾hotdog/trainhotdog/test
這兩個文件夾都有hotdog(有熱狗)和not-hotdog(無熱狗)兩個子文件夾,
子文件夾內都包含相應類的圖像。

#@save
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')
Downloading ../data/hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...

我們創建兩個實例來分別讀取訓練和測試數據集中的所有圖像文件。

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

下面顯示了前8個正類樣本圖片和最后8張負類樣本圖片。正如所看到的,[圖像的大小和縱橫比各有不同]。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

?
在這里插入圖片描述

?

在訓練期間,我們首先從圖像中裁切隨機大小和隨機長寬比的區域,然后將該區域縮放為 224 × 224 224 \times 224 224×224輸入圖像。
在測試過程中,我們將圖像的高度和寬度都縮放到256像素,然后裁剪中央 224 × 224 224 \times 224 224×224區域作為輸入。
此外,對于RGB(紅、綠和藍)顏色通道,我們分別標準化每個通道。
具體而言,該通道的每個值減去該通道的平均值,然后將結果除以該通道的標準差。

[數據增廣]

# 使用RGB通道的均值和標準差,以標準化每個通道
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])

[定義和初始化模型]

我們使用在ImageNet數據集上預訓練的ResNet-18作為源模型。
在這里,我們指定pretrained=True以自動下載預訓練的模型參數。
如果首次使用此模型,則需要連接互聯網才能下載。

pretrained_net = torchvision.models.resnet18(pretrained=True)

預訓練的源模型實例包含許多特征層和一個輸出層fc
此劃分的主要目的是促進對除輸出層以外所有層的模型參數進行微調。
下面給出了源模型的成員變量fc

pretrained_net.fc
Linear(in_features=512, out_features=1000, bias=True)

在ResNet的全局平均匯聚層后,全連接層轉換為ImageNet數據集的1000個類輸出。
之后,我們構建一個新的神經網絡作為目標模型。
它的定義方式與預訓練源模型的定義方式相同,只是最終層中的輸出數量被設置為目標數據集中的類數(而不是1000個)。

在下面的代碼中,目標模型finetune_net中成員變量features的參數被初始化為源模型相應層的模型參數。
由于模型參數是在ImageNet數據集上預訓練的,并且足夠好,因此通常只需要較小的學習率即可微調這些參數。

成員變量output的參數是隨機初始化的,通常需要更高的學習率才能從頭開始訓練。
假設Trainer實例中的學習率為 η \eta η,我們將成員變量output中參數的學習率設置為 10 η 10\eta 10η

finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

[微調模型]

首先,我們定義了一個訓練函數train_fine_tuning,該函數使用微調,因此可以多次調用。

# 如果param_group=True,輸出層中的模型參數將使用十倍的學習率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)

我們[使用較小的學習率],通過微調預訓練獲得的模型參數。

train_fine_tuning(finetune_net, 5e-5)
loss 0.220, train acc 0.915, test acc 0.939
999.1 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

在這里插入圖片描述

[為了進行比較,]我們定義了一個相同的模型,但是將其(所有模型參數初始化為隨機值)。
由于整個模型需要從頭開始訓練,因此我們需要使用更大的學習率。

scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)
loss 0.374, train acc 0.839, test acc 0.843
1623.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

在這里插入圖片描述

意料之中,微調模型往往表現更好,因為它的初始參數值更有效。

小結

  • 遷移學習將從源數據集中學到的知識遷移到目標數據集,微調是遷移學習的常見技巧。
  • 除輸出層外,目標模型從源模型中復制所有模型設計及其參數,并根據目標數據集對這些參數進行微調。但是,目標模型的輸出層需要從頭開始訓練。
  • 通常,微調參數使用較小的學習率,而從頭開始訓練輸出層可以使用更大的學習率。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/715390.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/715390.shtml
英文地址,請注明出處:http://en.pswp.cn/news/715390.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【計算機是怎么跑起來的】軟件,體驗一次手工匯編

【計算機是怎么跑起來的】軟件,體驗一次手工匯編 二進制機器語言匯編語言操作碼操作數寄存器內存地址和I/O地址參考書:計算機是怎么跑起來的 第三章外設在路上。。。先整理一下本書涉及的理論知識,反正后面做視頻也要重寫QAQ 程序的作用是驅動硬件工作,所以在編寫程序之前必…

【C++庖丁解牛】類與對象

📙 作者簡介 :RO-BERRY 📗 學習方向:致力于C、C、數據結構、TCP/IP、數據庫等等一系列知識 📒 日后方向 : 偏向于CPP開發以及大數據方向,歡迎各位關注,謝謝各位的支持 目錄 1.面向過程和面向對象…

對單例模式的餓漢式、懶漢式的思考

目錄 1 什么是單例模式?1.1 什么是餓漢式?1.2 什么是懶漢式? 2 我對餓漢式的思考3 懶漢式3.1 解決懶漢式的線程安全問題3.1.1 加鎖:synchronized(synchronized修飾靜態方法)3.1.2 對“3.1.1”性能的改進 1 …

環形鏈表詳解(讓你徹底理解環形鏈表)

文章目錄 一.什么是環形鏈表?二.環形鏈表的例題(力扣) 三.環形鏈表的延伸問題 補充 一.什么是環形鏈表? 環形鏈表是一種特殊類型的鏈表數據結構,其最后一個節點的"下一個"指針指向鏈表中的某個節點&#xff…

Python 教學平臺,支持“多班教學”的課程授課方式|ModelWhale 版本更新

龍行龘龘、前程朤朤,ModelWhale 新一輪的版本更新,期待為大家帶來更優質的使用體驗。 本次更新中,ModelWhale 主要進行了以下功能迭代: 新增 課程(包括課件、作業、算力)按班級管理(團隊版? …

springcloud的搭建和封裝,已進行開源,相互學習代碼知識。

springcloud架構的統一父工程,(管理子模塊,管理依賴插件,依賴版本等) abillty:能力服務塊:存放一些非業務相關的微服務,比如網關,身份認證等 exce: 網關中的一些異常信息處理 gatewa…

基于Springboot的人事管理系統 (有報告)。Javaee項目,springboot項目。

演示視頻: 基于Springboot的人事管理系統 (有報告)。Javaee項目,springboot項目。 項目介紹: 采用M(model)V(view)C(controller)三層體系結構&am…

【Git】merge時報錯:refusing to merge unrelated histories

文章目錄 一、問題二、解決辦法1、將feature分支的東西追加到master分支中2、將feature里的東西直接覆蓋到master分支中 一、問題 今天將feature分支合并到master時報錯:refusing to merge unrelated histories(拒絕合并無關歷史) 報錯原因&…

一篇文章速通static關鍵字(JAVA)

目錄 1.原理——內存機制 1.1 修飾對象 1.2 lifecycle生命周期 2. 靜態屬性(類屬性)和實例屬性(對象屬性) 2.1 定義方式 2.2 調用方法 3. 靜態方法和屬性 3.1 在同一個類中 3.2 在不同類中 4.總結(關鍵&#x…

SQLSyntaxEProrException異常產生原因及解決方案

java.sq1.SQLSyntaxEProrException異常產生原因及解決方案 01 異常的發生場景 在我mybatis-plus寫了一個查詢接口后出現的問題 java.sq1.SQLSyntaxEProrException日志報錯的意思是sql語法問題 02 異常的產生及其原因 我最開始又認為是MySQL數據庫表設計的問題&#xff0c…

ROS2從入門到精通:理論與實戰

ROS是什么? 隨著人工智能技術的飛速發展與進步,機器人的智能化已經成為現代機器人發展的終極目標。機器人發展的速度在不斷提升,應用范圍也在不斷拓展,例如自動駕駛、移動機器人、操作機器人、信息機器人等。機器人系統是很多復雜…

外貿福利 PHP源碼 WhatsApp 營銷 - 批量發件人、聊天、機器人、SaaS 搭建

WhatsApp 營銷工具對于外貿人員來說至關重要。隨著全球貿易的不斷發展,WhatsApp已成為了許多國際貿易商之間溝通的首選工具之一。通過利用WhatsApp營銷工具,外貿人員可以輕松地與客戶建立聯系,傳遞產品信息,進行價格談判&#xff…

Revit-二開之東西南北立面FilledRegion的CurveLoop計算-(4)

東西南北FilledRegion的CurveLoop計算 上一篇以東立面視圖為例創建FilledRegion,接下來我們將立面視圖創建FilledRegion的CurveLoop匯總一下。 上圖是對四個立面坐標系間的繪制方便我們計算FilledRegion的CurveLoop。 東立面CurveLoop計算 private CurveLoop GetEastCurveL…

3.1網安學習第三階段第一周回顧(個人學習記錄使用)

本周重點 ①HTML/JavaScript/CSS ②PHP ③正則表達式/文件上傳/文件讀寫 ④AJAX不跳轉提交 ⑤ OOP面向對象編程 本周主要內容 DAY1 HTML/JavaScript/CSS ①HTML 一、基本結構 <HTML> <head> //頭部內容 <title>網頁標題</title> </head&…

內網滲透-DC-9靶機滲透

攻擊機&#xff1a;kali 192.168.236.137 目標機&#xff1a;dc-9 192.168.236.138 一、信息收集 1.使用arp-scan -l和nmap進行主機發現和端口信息收集 nmap -sS -T5 --min-rate 10000 192.168.236.138 -sC -p- 發現22端口被阻塞 2.whatweb收集一下cms指紋信息 what http…

Vue開發實例(七)Axios的安裝與使用

說明&#xff1a; 如果只是在前端&#xff0c;axios常常需要結合mockjs使用&#xff0c;如果是前后端分離&#xff0c;就需要調用對應的接口&#xff0c;獲取參數&#xff0c;傳遞參數&#xff1b;由于此文章只涉及前端&#xff0c;所以我們需要結合mockjs使用&#xff1b;由于…

《熱辣滾燙》:用堅持不懈開啟逆境中的職場出路

"你只活一次&#xff0c;所以被嘲笑也沒有關系&#xff0c;想哭也沒有關系&#xff0c;失敗更沒有關系。" “人生就像一場拳擊賽&#xff0c;你站不起來&#xff0c;就永遠不知道自己有多強” “命運只負責洗牌&#xff0c;出牌的永遠是自己。” 在今年的賀歲檔電影市…

云時代【6】—— 鏡像 與 容器

云時代【6】—— 鏡像 與 容器 四、Docker&#xff08;三&#xff09;鏡像 與 容器1. 鏡像&#xff08;1&#xff09;定義&#xff08;2&#xff09;相關指令&#xff08;3&#xff09;實戰演習鏡像容器基本操作離線遷移鏡像鏡像的壓縮與共享 2. 容器&#xff08;1&#xff09;…

為什么模電這么難學?這是我見過最好的回答

大家好&#xff0c;我是磚一&#xff0c;有很多人抱怨模電難學&#xff0c;被譽為電子信息掛科率最高之一&#xff0c;下面聽我分析一下為啥模電這么難學&#xff1f; 01 理科的抽象思維 在高等教育體系中&#xff0c;模電是涉及半導體方向的第一門工程類課程&#xff0c;是一…

2024年3月5-7日年生物發酵裝備展-環科環保科技

參展企業介紹 山東環科環保科技有限公司,是一家集環保設備的設計、制造、安裝、服務及環境治理工程總承包于一體的企業。 公司長期專注于大氣、水、危固廢三大領域&#xff0c;以科技創造碧水藍天&#xff0c;為客戶提供環保解決方案。 以穩定的產品及服務質量、適用的技術、…