深度學習-計算機視覺-微調 Fine-tune

1. 遷移學習

遷移學習(transfer learning)是一種機器學習方法,通過將源數據集(如ImageNet)上訓練得到的模型知識遷移到目標數據集(如特定場景的椅子識別任務)。這種方法的核心在于利用預訓練模型已學習到的通用特征,提升目標任務的性能,尤其是在數據量有限的情況下。

ImageNet等大規模數據集雖然包含多樣化的圖像類別(如動物、植物、人造物體等),但在此類數據集上訓練的模型能夠自動學習低層次和高層次的通用視覺特征。這些特征包括但不限于:

  • 邊緣檢測:識別圖像中的輪廓和邊界。
  • 紋理分析:捕捉不同材質的表面特性。
  • 形狀建模:提取物體的幾何結構。
  • 對象組合關系:理解多物體間的空間和語義關聯。

盡管ImageNet可能不直接包含大量椅子圖像,但模型學習到的上述特征仍然能夠有效支持椅子識別任務。例如,椅子的邊緣結構與動物輪廓可能共享相似的檢測模式,而木質或布藝椅子的紋理特征可能與其他物體(如家具或服飾)的紋理分析機制高度相關。

通過遷移學習,目標任務無需從零開始訓練模型,從而顯著減少計算資源和時間成本。此外:

  • 在目標數據集較小的情況下,遷移學習能夠避免過擬合問題
  • 預訓練模型的高層次特征提取能力可提升目標任務的泛化性能
  • 適用于跨領域應用,如醫學影像分析、自動駕駛等場景

2. 微調

遷移學習中的常見技巧:微調(fine-tuning)。

微調包括以下四個步驟。

  1. 在源數據集(例如ImageNet數據集)上預訓練神經網絡模型,即源模型

  2. 創建一個新的神經網絡模型,即目標模型。這將復制源模型上的所有模型設計及其參數(輸出層除外)。我們假定這些模型參數包含從源數據集中學到的知識,這些知識也將適用于目標數據集。我們還假設源模型的輸出層與源數據集的標簽密切相關;因此不在目標模型中使用該層。

  3. 向目標模型添加輸出層,其輸出數是目標數據集中的類別數。然后隨機初始化該層的模型參數。

  4. 在目標數據集(如椅子數據集)上訓練目標模型。輸出層將從頭開始進行訓練,而所有其他層的參數將根據源模型的參數進行微調

微調中的 權重初始化

微調

  • 遷移學習將從源數據集中學到的知識遷移到目標數據集,微調是遷移學習的常見技巧

  • 除輸出層外,目標模型從源模型中復制所有模型設計及其參數,并根據目標數據集對這些參數進行微調。

  • 但是,目標模型的輸出層需要從頭開始訓練。

  • 通常,微調參數使用較小的學習率,而從頭開始訓練輸出層可以使用更大的學習率。

3. 代碼實現

%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l#@save
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

下面顯示了前8個正類樣本圖片和最后8張負類樣本圖片。

正如所看到的,圖像的大小和縱橫比各有不同。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

數據增廣:

對于RGB(紅、綠和藍)顏色通道,我們分別標準化每個通道。 具體而言,該通道的每個值減去該通道的平均值,然后將結果除以該通道的標準差。

# 使用RGB通道的均值和標準差,以標準化每個通道
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])

查看源模型:

我們使用在ImageNet數據集上預訓練的ResNet-18作為源模型。

在這里,我們指定pretrained=True以自動下載預訓練的模型參數。 如果首次使用此模型,則需要連接互聯網才能下載。

預訓練的源模型實例包含許多特征層和一個輸出層fc。 此劃分的主要目的是促進對除輸出層以外所有層的模型參數進行微調。

下面給出了源模型的成員變量fc

pretrained_net = torchvision.models.resnet18(pretrained=True)
pretrained_net.fc輸出:
Linear(in_features=512, out_features=1000, bias=True)

定義目標模型finetune_net:【換輸出層

  1. finetune_net = torchvision.models.resnet18(pretrained=True) 載入 在 ImageNet 上預訓練好的 ResNet-18 網絡(權重已固定,特征提取器已具備通用視覺能力)。

  2. finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2) 把 ResNet-18 最后的 1000 類全連接層 替換成 只有 2 個輸出的新層(二分類任務)。

    1. in_features 保留原模型提取到的特征維度(512)。

    2. 輸出維度改為 2,對應新任務的類別數。

finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

首先,我們定義了一個訓練函數train_fine_tuning,該函數使用微調,因此可以多次調用。

其中特征提取層學習率低,最后一層輸出層學習率高。

# 如果param_group=True,輸出層中的模型參數將使用十倍的學習率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)

我們使用較小的學習率,通過微調預訓練獲得的模型參數。

train_fine_tuning(finetune_net, 5e-5)輸出:
loss 0.220, train acc 0.915, test acc 0.939
999.1 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

為了進行比較,我們定義了一個相同的模型,但是將其所有模型參數初始化為隨機值。

由于整個模型需要從頭開始訓練,因此我們需要使用更大的學習率。

scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)輸出:
loss 0.374, train acc 0.839, test acc 0.843
1623.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/93789.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/93789.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/93789.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

STL庫——string(類函數學習)

? ? ? ? ? づ?ど 🎉 歡迎點贊支持🎉 個人主頁:勵志不掉頭發的內向程序員; 專欄主頁:C語言; 文章目錄 前言 一、STL簡介 二、string類的優點 三、標準庫中的string類 四、string的成員函數 4.1、構造…

登上Nature!清華大學光學神經網絡研究突破

2025深度學習發論文&模型漲點之——光學神經網絡光學神經網絡的基本原理是利用光的傳播、干涉、衍射等特性來實現神經網絡中的信息處理和計算。在傳統神經網絡中,信息以電信號的形式在電子元件之間傳輸和處理,而在光學神經網絡中,信息則以…

【java】對word文件設置只讀權限

文件流輸出時 template.getXWPFDocument().enforceCommentsProtection(); 文件輸出時 //打開己創建的word文檔 XWPFDocument document new XWPFDocument(new FileInputStream("output.docx")); //設置文檔為只讀 document.enforceReadonlyProtection(); //保存文…

Zookeeper 在 Kafka 中扮演了什么角色?

在 Apache Kafka 的早期架構中,ZooKeeper 扮演了分布式協調服務角色,負責管理和協調整個 Kafka 集群。 盡管新版本的 Kafka 正在逐步移除對 ZooKeeper 的依賴,但在許多現有和較早的系統中,了解 ZooKeeper 的作用仍然非常重要。 Zo…

什么叫做 “可迭代的產品矩陣”?如何落地??

“可迭代的產品矩陣” 不是靜態的產品組合,而是圍繞用戶需求與商業目標構建的動態生態。它以核心產品為根基,通過多維度延伸形成產品網絡,同時具備根據市場反饋持續優化的彈性,讓產品體系既能覆蓋用戶全生命周期需求,又…

Nginx代理配置詳解:正向代理與反向代理完全指南

系列文章索引: 第一篇:《Nginx入門與安裝詳解:從零開始搭建高性能Web服務器》第二篇:《Nginx基礎配置詳解:nginx.conf核心配置與虛擬主機實戰》第三篇:《Nginx代理配置詳解:正向代理與反向代理…

Vue3 Element-plus 封裝Select下拉復選框選擇器

廢話不多說&#xff0c;樣式如下&#xff0c;代碼如下&#xff0c;需要自取<template><el-selectv-model"selectValue"class"checkbox-select"multiple:placeholder"placeholder":style"{ width: width }"change"change…

jenkins 自動部署

一、win10 環境安裝&#xff1a; 1、jdk 下載安裝&#xff1a;Index of openjdk-local 2、配置環境變量&#xff1a; 3、jenkins 下載&#xff1a;Download and deploy 下載后的結果&#xff1a;jenkins.war 4、jenkins 啟動&#xff1a; 5、創建管理員用戶 admin 登錄系統…

2020 GPT3 原文 Language Models are Few-Shot Learners 精選注解

本文為個人閱讀GPT3&#xff0c;部分內容注解&#xff0c;由于GPT3原文篇幅較長&#xff0c;且GPT3無有效開源信息 這里就不再一一粘貼&#xff0c;僅對原文部分內容做注解&#xff0c;僅供參考 詳情參考原文鏈接 原文鏈接&#xff1a;https://arxiv.org/pdf/2005.14165 語言模…

設計模式筆記_行為型_迭代器模式

1. 迭代器模式介紹迭代器模式&#xff08;Iterator Pattern&#xff09;是一種行為設計模式&#xff0c;旨在提供一種方法順序訪問一個聚合對象中的各個元素&#xff0c;而又不需要暴露該對象的內部表示。這個模式的主要目的是將集合的遍歷與集合本身分離&#xff0c;使得用戶可…

【Part 4 未來趨勢與技術展望】第一節|技術上的抉擇:三維實時渲染與VR全景視頻的共生

《VR 360全景視頻開發》專欄 將帶你深入探索從全景視頻制作到Unity眼鏡端應用開發的全流程技術。專欄內容涵蓋安卓原生VR播放器開發、Unity VR視頻渲染與手勢交互、360全景視頻制作與優化&#xff0c;以及高分辨率視頻性能優化等實戰技巧。 &#x1f4dd; 希望通過這個專欄&am…

mac查看nginx安裝位置 mac nginx啟動、重啟、關閉

安裝工具&#xff1a;homebrew步驟&#xff1a;1、打開終端&#xff0c;習慣性命令&#xff1a;brew update //結果&#xff1a;Already up-to-date.2、終端繼續執行命令&#xff1a;brew search nginx //查詢要安裝的軟件是否存在3、執行命令&#xff1a;brew info nginx4. …

網絡通信的基本概念與設備

目錄 一、互聯網 二、JAVA跨平臺與C/C的原理 1、JAVA跨平臺的原理 2、C/C跨平臺的原理 三、網絡互連模型 四、客戶端與服務器 五、計算機之間的通信基礎 1、IP地址與MAC地址 2、ARP與ICMP對比 ①ARP協議&#xff08;地址解析協議&#xff09; ②ICMP協議&#xff08…

云原生俱樂部-k8s知識點歸納(1)

這篇文章主要是講講k8s中的知識點歸納&#xff0c;以幫助理解。雖然平時也做筆記和總結&#xff0c;但是就將內容復制過來不太好&#xff0c;而且我比較喜歡打字。因此知識點歸納總結還是以敘述的口吻來說說&#xff0c;并結合我的理解加以論述。k8s和docker首先講一講docker和…

基于Node.js+Express的電商管理平臺的設計與實現/基于vue的網上購物商城的設計與實現/基于Node.js+Express的在線銷售系統

基于Node.jsExpress的電商管理平臺的設計與實現/基于vue的網上購物商城的設計與實現/基于Node.jsExpress的在線銷售系統

Git 對象存儲:理解底層原理,實現高效排錯與存儲優化

### 探秘 Git 對象存儲&#xff1a;底層原理與優化實踐#### 一、Git 對象存儲的底層原理 Git 采用**內容尋址文件系統**&#xff0c;核心機制如下&#xff1a; 1. **對象類型與存儲** - **Blob 對象**&#xff1a;存儲文件內容&#xff0c;通過 git hash-object 生成唯一 SHA-…

【2025CVPR-目標檢測方向】RaCFormer:通過基于查詢的雷達-相機融合實現高質量的 3D 目標檢測

1. 研究背景與動機? ?問題?:現有雷達-相機融合方法依賴BEV特征融合,但相機圖像到BEV的轉換因深度估計不準確導致特征錯位;雷達BEV特征稀疏,相機BEV特征因深度誤差存在畸變。 ?核心思路?:提出跨視角查詢融合框架,通過對象查詢(object queries)同時采樣圖像視角(原…

【每日一題】Day 7

560.和為K的子數組 題目&#xff1a; 給你一個整數數組 nums 和一個整數 k &#xff0c;請你統計并返回該數組中和為 k 的子數組的個數 。 子數組是數組中元素的連續非空序列。 示例 1&#xff1a; 輸入&#xff1a;nums [1,1,1], k 2 輸出&#xff1a;2 示例 2&#x…

3ds MAX文件/貼圖名稱亂碼?6大根源及解決方案

在3ds MAX渲染階段&#xff0c;文件或貼圖名稱亂碼導致渲染失敗&#xff0c;是困擾眾多用戶的常見難題。其背后原因多樣&#xff0c;精準定位方能高效解決&#xff1a;亂碼核心根源剖析字符編碼沖突 (最常見)非ASCII字符風險&#xff1a; 文件路徑或名稱包含中文、日文、韓文等…

鏈路聚合路由器OpenMPTCProuter源碼編譯與運行

0.前言 前面寫了兩篇關于MPTCP的文章&#xff1a; 《鏈路聚合技術——多路徑傳輸Multipath TCP(MPTCP)快速實踐》《使用MPTCPBBR進行數據傳輸&#xff0c;讓網絡又快又穩》 對MPTCP有了基本的了解與實踐&#xff0c;并在虛擬的網絡拓撲中實現了鏈路帶寬的疊加。 1.OpenMPTC…