卷積神經網絡遷移學習:原理與實踐指南

引言

? ? ? ? 在深度學習領域,卷積神經網絡(CNN)已經在計算機視覺任務中取得了巨大成功。然而,從頭開始訓練一個高性能的CNN模型需要大量標注數據和計算資源。遷移學習(Transfer Learning)技術為我們提供了一種高效解決方案,它能夠將預訓練模型的知識遷移到新任務中,顯著減少訓練時間和數據需求。本文將全面介紹CNN遷移學習的原理、優勢及實踐方法。

1、內容

? ? ? 遷移學習是指利用已經訓練好的模型,在新的任務上進行微調。遷移學習可以加快模型訓練速度,提高模型性能,并且在數據稀缺的情況下也能很好地工作

2、步驟

1、選擇預訓練的模型和適當的層:通常,我們會選擇在大規模圖像數據集(如ImageNet)上預訓練的模型,如VGG、ResNet等。然后,根據新數據集的特點,選擇需要微調的模型層。對于低級特征的任務(如邊緣檢測),最好使用淺層模型的層,而對于高級特征的任務(如分類),則應選擇更深層次的模型。

2、凍結預訓練模型的參數:保持預訓練模型的權重不變,只訓練新增加的層或者微調一些層,避免因為在數據集中過擬合導致預訓練模型過度擬合。

3、在新數據集上訓練新增加的層:在凍結預訓練模型的參數情況下,訓練新增加的層。這樣,可以使新模型適應新的任務,從而獲得更高的性能。

4、微調預訓練模型的層:在新層上進行訓練后,可以解凍一些已經訓練過的層,并且將它們作為微調的目標。這樣做可以提高模型在新數據集上的性能。

5、評估和測試:在訓練完成之后,使用測試集對模型進行評估。如果模型的性能仍然不夠好,可以嘗試調整超參數或者更改微調層。

Resnet網絡:

原理:

? ? ? ?卷積神經網絡都是通過卷積層和池化層的疊加組成的。 在實際的試驗中發現,隨著卷積層和池化層的疊加,學習效果不會逐漸變好,反而出現2個問題:

1、梯度消失和梯度爆炸

梯度消失:若每一層的誤差梯度小于1,反向傳播時,網絡越深,梯度越趨近于0

梯度爆炸:若每一層的誤差梯度大于1,反向傳播時,網絡越深,梯度越來越大

2、退化問題

? ? ? ?為了解決梯度消失或梯度爆炸問題,論文提出通過數據的預處理以及在網絡中使用 BN(Batch Normalization)層來解決。

? ? ? 為了解決深層網絡中的退化問題,可以人為地讓神經網絡某些層跳過下一層神經元的連接,隔層相連,弱化每層之間的強聯系。這種神經網絡被稱為 殘差網絡 (ResNets)

1、18層resnet結構:

2、BN(Batch Normalization)

實例

1、導入相關的庫

import torch
from torch.utils.data import DataLoader,Dataset  #數據包管理工具,打包數據,
from torchvision import transforms
from torch import nn
import torchvision.models as models
from PIL import Image
import numpy as np

2、調取模型并凍結參數

#不再需要自己來搭建模型了。預訓練的文件也加載進去了。
# 將resnet18模型遷移到食物分類項目中.#殘差網絡是固定的網絡結構,不需要你自己來類定義了。
resnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
#weights=models.ResNet18_Weights.DEFAULT表示使用在 ImageNet 數據集上預先訓練好的權重
for param in resnet_model.parameters():print(param)param.requires_grad=False      #凍結
#模型所有參數(即權重和偏差)的requires_grad屬性設置為False,從而凍結所有模型參數,

詳細說明:

  1. models.resnet18()加載ResNet18架構

  2. weights=models.ResNet18_Weights.DEFAULT指定使用官方預訓練權重

  3. 遍歷所有參數并凍結

3、對網絡模型進行微調

in_features=resnet_model.fc.in_features    #獲取模型原輸入的特征個數
resnet_model.fc=nn.Linear(in_features,20)  #創建一個全連接層

4、保存需要訓練的參數

params_to_update=[]    #保存需要訓練的參數,僅僅包含全連接層的參數
for param in resnet_model.parameters():if param.requires_grad==True:params_to_update.append(param)

5、數據預處理

data_transforms={
'train':
transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),  # 隨機旋轉,-45到45度之間隨機選transforms.CenterCrop(224),  # 從中心開始裁剪[256,256]transforms.RandomHorizontalFlip(p=0.5),  # 隨機水平翻轉 選擇一個概率概率transforms.RandomVerticalFlip(p=0.5),  # 隨機垂直翻轉# transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),  # 概率轉換成灰度率,3通道就是R=G=Btransforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'valid':
transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}

數據預處理包括:

  • 訓練集使用多種數據增強(隨機旋轉、水平翻轉等)

  • 驗證集只進行簡單的resize和歸一化

  • 歸一化參數使用ImageNet的均值和標準差

6、自定義數據集的類

class food_dataset(Dataset):def __init__(self,file_path,transform=None):self.file_path=file_pathself.imgs=[]self.labels=[]self.transform=transformwith open(self.file_path) as f:samples=[x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image=Image.open(self.imgs[idx])if self.transform:image=self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label,dtype=np.int64))return image,label

這個自定義Dataset類:

  1. 從文本文件讀取圖像路徑和標簽

  2. 實現__len____getitem__方法,供DataLoader使用

  3. 應用指定的transform處理圖像

7、數據加載器準備

# 創建訓練集和測試集實例
training_data = food_dataset(file_path='trainda.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='testda.txt', transform=data_transforms['valid'])# 創建數據加載器
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

數據加載器提供:

  • 批量加載功能(batch_size=64)

  • 訓練數據隨機打亂(shuffle=True)

  • 多線程數據預讀取

8、訓練和測試流程

def train(dataloader,model,loss_fn,optimizer):model.train()   #告訴模型,我要開始訓練,模型中w進行隨機化操作,已經更新w。在訓練過程中,w會被修改的
#pytorch提供2種方式來切換訓練和測試的模式,分別是:model.train()和 model.eval()。
#一般用法是:在訓練開始之前寫上model.trian(),在測試時寫上 model.eval()for X,y in dataloader:       #其中batch為每一個數據的編號X,y=X.to(device),y.to(device)    #把訓練數據集和標簽傳入cpu或GPUpred=model.forward(X)    #.forward可以被省略,父類中已經對次功能進行了設置。自動初始化loss=loss_fn(pred,y)     #通過交叉熵損失函數計算損失值loss# Backpropagation 進來一個batch的數據,計算一次梯度,更新一次網絡optimizer.zero_grad()    #梯度值清零loss.backward()          #反向傳播計算得到每個參數的梯度值woptimizer.step()         #根據梯度更新網絡w參數best_acc=0
def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= size

9、循環訓練

# 定義損失函數和優化器
loss_fn = nn.CrossEntropyLoss()  # 交叉熵損失
optimizer = torch.optim.Adam(params_to_update, lr=0.001)  # Adam優化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # 學習率調度# 訓練10個epoch
epoch = 10
for i in range(epoch):print(f"Epoch {i + 1}")train(train_dataloader, model, loss_fn, optimizer)  # 訓練scheduler.step()  # 更新學習率test(test_dataloader, model, loss_fn)  # 測試print('Best accuracy:', best_acc)  # 打印最佳準確率

主循環流程:

  1. 定義損失函數和優化器

  2. 設置學習率調度器(每5個epoch學習率減半)

  3. 進行10輪訓練和測試

  4. 打印最終最佳準確率

結果展示:

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/78467.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/78467.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/78467.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

圖論---樸素Prim(稠密圖)

O( n ^2 ) 題目通常會提示數據范圍&#xff1a; 若 V ≤ 500&#xff0c;兩種方法均可&#xff08;樸素Prim更穩&#xff09;。 若 V ≤ 1e5&#xff0c;必須用優先隊列Prim vector 存圖。 // 最小生成樹 —樸素Prim #include<cstring> #include<iostream> #i…

Spring-Cache替換Keys為Scan—負優化?

背景 使用ORM工具是往往會配合緩存框架實現三級緩存提高查詢效率&#xff0c;spring-cache配合redis是非常常規的實現方案&#xff0c;如未做特殊配置&#xff0c;CacheEvict(allEntries true) 的批量驅逐方式&#xff0c;默認使用keys的方式查詢歷史緩存列表而后delete&…

【N8N】Docker Desktop + WSL 安裝過程(Docker Desktop - WSL update Failed解決方法)

背景說明&#xff1a; 因為要用n8n&#xff0c;官網推薦這個就下載了&#xff0c;然后又是一堆卡的安裝問題記錄過程。 1. 下載安裝包 直接去官網Get Docker | Docker Docs下載 下載的是第一個windows - x86_64. &#xff08;*下面那個beta的感覺是測試版&#xff09; PS&am…

RT Thread 發生異常時打印輸出cpu寄存器信息和棧數據

打印輸出發生hardfault時,當前棧十六進制數據和cpu寄存器信息 在發生 HardFault 時,打印當前棧的十六進制數據和 CPU 寄存器信息是非常重要的調試手段。以下是如何實現這一功能的具體步驟和示例代碼。 1. 實現 HardFault 處理函數 我們需要在 HardFault 中捕獲異常上下文,…

【安裝neo4j-5.26.5社區版 完整過程】

1. 安裝java 下載 JDK21-windows官網地址 配置環境變量 在底下的系統變量中新建系統變量&#xff0c;變量名為JAVA_HOME21&#xff0c;變量值為JDK文件夾路徑&#xff0c;默認為&#xff1a; C:\Program Files\Java\jdk-21然后在用戶變量的Path中&#xff0c;添加下面兩個&am…

android jatpack Compose 多數據源依賴處理:從狀態管理到精準更新的架構設計

Android Compose 多接口數據依賴管理&#xff1a;ViewModel 狀態共享最佳實踐 &#x1f4cc; 問題背景 在 Jetpack Compose 開發中&#xff0c;經常遇到以下場景&#xff1a; 頁面由多個獨立接口數據組成&#xff08;如 Part1、Part2&#xff09;Part2 的某些 UI 需要依賴 P…

面試之消息隊列

消息隊列場景 什么是消息隊列&#xff1f; 消息隊列是一個使用隊列來通信的組件&#xff0c;它的本質就是個轉發器&#xff0c;包含發消息、存消息、消費消息。 消息隊列怎么選型&#xff1f; 特性ActiveMQRabbitMQRocketMQKafka單機吞吐量萬級萬級10萬級10萬級時效性毫秒級…

GStreamer 簡明教程(十一):插件開發,以一個音頻生成(Audio Source)插件為例

系列文章目錄 GStreamer 簡明教程&#xff08;一&#xff09;&#xff1a;環境搭建&#xff0c;運行 Basic Tutorial 1 Hello world! GStreamer 簡明教程&#xff08;二&#xff09;&#xff1a;基本概念介紹&#xff0c;Element 和 Pipeline GStreamer 簡明教程&#xff08;三…

Linux kernel signal原理(下)- aarch64架構sigreturn流程

一、前言 在上篇中寫到了linux中signal的處理流程&#xff0c;在do_signal信號處理的流程最后&#xff0c;會通過sigreturn再次回到線程現場&#xff0c;上篇文章中介紹了在X86_64架構下的實現&#xff0c;本篇中介紹下在aarch64架構下的實現原理。 二、sigaction系統調用 #i…

華為OD機試真題——簡易內存池(2025A卷:200分)Java/python/JavaScript/C++/C/GO最佳實現

2025 A卷 200分 題型 本文涵蓋詳細的問題分析、解題思路、代碼實現、代碼詳解、測試用例以及綜合分析&#xff1b; 并提供Java、python、JavaScript、C、C語言、GO六種語言的最佳實現方式&#xff01; 本文收錄于專欄&#xff1a;《2025華為OD真題目錄全流程解析/備考攻略/經驗…

騰訊一面面經:總結一下

1. Java 中的 和 equals 有什么區別&#xff1f;比較對象時使用哪一個 1. 操作符&#xff1a; 用于比較對象的內存地址&#xff08;引用是否相同&#xff09;。 對于基本數據類型、 比較的是值。&#xff08;8種基本數據類型&#xff09;對于引用數據類型、 比較的是兩個引…

計算機網絡中的DHCP是什么呀? 詳情解答

目錄 DHCP 是什么&#xff1f; DHCP 的工作原理 主要功能 DHCP 與網絡安全的關系 1. 正面作用 2. 潛在安全風險 DHCP 的已知漏洞 1. 協議設計缺陷 2. 軟件實現漏洞 3. 配置錯誤導致的漏洞 4. 已知漏洞總結 舉例說明 DHCP 與網絡安全 如何提升 DHCP 安全性 總結 D…

2025 年導游證報考條件新政策解讀與應對策略

2025 年導游證報考政策有了不少新變化&#xff0c;這些變化會對報考者產生哪些影響&#xff1f;我們又該如何應對&#xff1f;下面就為大家詳細解讀新政策&#xff0c;并提供實用的應對策略。 最引人注目的變化當屬中職旅游類專業學生的報考政策。以往&#xff0c;中專學歷報考…

【物聯網】基于LORA組網的遠程環境監測系統設計(ThingsCloud云平臺版)

演示視頻: 基于LORA組網的遠程環境監測系統設計(ThingsCloud云平臺版) 前言:本設計是基于ThingsCloud云平臺版,還有另外一個版本是基于機智云平臺版本,兩個設計只是云平臺和手機APP的區別,其他功能都一樣。如下鏈接: 【物聯網】基于LORA組網的遠程環境監測系統設計(機…

SQL 函數進行左邊自動補位fnPadLeft和FORMAT

目錄 1.問題 2.解決 方式1 方式2 3.結果 1.問題 例如在SQL存儲過程中&#xff0c;將1 或10 或 100 長度不足的時候&#xff0c;自動補足長度。 例如 1 → 001 10→ 010 100→100 2.解決 方式1 SELECT FORMAT (1, 000) AS FormattedNum; SELECT FORMAT(12, 000) AS Form…

Nacos簡介—2.Nacos的原理簡介

大綱 1.Nacos集群模式的數據寫入存儲與讀取問題 2.基于Distro協議在啟動后的運行規則 3.基于Distro協議在處理服務實例注冊時的寫路由 4.由于寫路由造成的數據分片以及隨機讀問題 5.寫路由 數據分區 讀路由的CP方案分析 6.基于Distro協議的定時同步機制 7.基于Distro協…

中電金信聯合阿里云推出智能陪練Agent

在金融業加速數智化轉型的今天&#xff0c;提升服務效率與改善用戶體驗已成為行業升級的核心方向。面對這一趨勢&#xff0c;智能體與智能陪練的結合應用&#xff0c;正幫助金融機構突破傳統業務模式&#xff0c;開拓更具競爭力的創新機遇。 在近日召開的阿里云AI勢能大會期間&…

十分鐘恢復服務器攻擊——群聯AI云防護系統實戰

場景描述 服務器遭遇大規模DDoS攻擊&#xff0c;導致服務不可用。通過群聯AI云防護系統的分布式節點和智能調度功能&#xff0c;快速切換流量至安全節點&#xff0c;清洗惡意流量&#xff0c;10分鐘內恢復業務。 技術實現步驟 1. 啟用智能調度API觸發節點切換 群聯系統提供RE…

LLM量化技術全景:GPTQ、QAT、AWQ、GGUF與GGML

01 引言 本文介紹的是在 LLM 討論中經常聽到的各種量化技術。本文的目的是提供一步一步的解釋和代碼&#xff0c;讓大家可以自己使用這些技術來壓縮模型。 閑話少說&#xff0c;我們來研究一下吧&#xff01; 02 Quantization 量化是指將高精度數字轉換為低精度數字。低精…

IP的基礎知識以及相關機制

IP地址 1.IP地址的概念 IP地址是分配給連接到互聯網或局域網中的每一個設備的唯一標識符 也就是說IP地址是你設備在網絡中的定位~ 2.IP版本~ IP版本分為IPv4和IPv6&#xff0c;目前我們最常用的還是IPv4~~但是IPv4有個缺點就是地址到現在為止&#xff0c;已經接近枯竭~~&…