深度學習之第八課遷移學習(殘差網絡ResNet)

目錄

簡介

一、遷移學習

1.什么是遷移學習

2. 遷移學習的步驟

二、殘差網絡ResNet

1.了解ResNet

2.ResNet網絡---殘差結構

三、代碼分析

1. 導入必要的庫

2. 模型準備(遷移學習)

3. 數據預處理

4. 自定義數據集類

5. 數據加載器

6. 設備配置

7. 訓練函數

8. 測試函數

9. 訓練配置和執行

整體流程總結


簡介

????????經過長久的卷積神經網絡的學習、我們學習了如何提高模型的準確率,但是最終我們的準確率還是沒達到百分之八十。原因是因為我們本身模型的局限,面對現有很多成熟的模型,它們有很好的效果,都是經過多次訓練選取了最佳的參數,那我們能不能去使用哪些大佬的模型呢?

????????答案是可以的,這就使用到遷移學習的知識。

深度學習之第五課卷積神經網絡 (CNN)如何訓練自己的數據集(食物分類)

深度學習之第六課卷積神經網絡 (CNN)如何保存和使用最優模型

深度學習之第七課卷積神經網絡 (CNN)調整學習率

一、遷移學習

1.什么是遷移學習

????????遷移學習是指利用已經訓練好的模型,在新的任務上進行微調。遷移學習可以加快模型訓練速度,提高模型性能,并且在數據稀缺的情況下也能很好地工作。

2. 遷移學習的步驟

????????1、選擇預訓練的模型和適當的層:通常,我們會選擇在大規模圖像數據集(如ImageNet)上預訓練的模型,如VGG、ResNet等。然后,根據新數據集的特點,選擇需要微調的模型層。對于低級特征的任務(如邊緣檢測),最好使用淺層模型的層,而對于高級特征的任務(如分類),則應選擇更深層次的模型。

????????2、凍結預訓練模型的參數:保持預訓練模型的權重不變,只訓練新增加的層或者微調一些層,避免因為在數據集中過擬合導致預訓練模型過度擬合。

????????3、在新數據集上訓練新增加的層:在凍結預訓練模型的參數情況下,訓練新增加的層。這樣,可以使新模型適應新的任務,從而獲得更高的性能。

????????4、微調預訓練模型的層:在新層上進行訓練后,可以解凍一些已經訓練過的層,并且將它們作為微調的目標。這樣做可以提高模型在新數據集上的性能。

????????5、評估和測試:在訓練完成之后,使用測試集對模型進行評估。如果模型的性能仍然不夠好,可以嘗試調整超參數或者更改微調層。

太多概念,我們直接使用殘差網絡進行遷移學習。

二、殘差網絡ResNet

1.了解ResNet

????????ResNet 網絡是在 2015年 由微軟實驗室中的何凱明等幾位大神提出,斬獲當年ImageNet競賽中分類任務第一名,目標檢測第一名。獲得COCO數據集中目標檢測第一名,圖像分割第一名。

傳統卷積神經網絡存在的問題?

卷積神經網絡都是通過卷積層和池化層的疊加組成的。 在實際的試驗中發現,隨著卷積層和池化層的疊加,學習效果不會逐漸變好,反而出現2個問題:

????????1、梯度消失和梯度爆炸 梯度消失:若每一層的誤差梯度小于1,反向傳播時,網絡越深,梯度越趨近于0 梯度爆炸:若每一層的誤差梯度大于1,反向傳播時,網絡越深,梯度越來越大

????????2、退化問題

如何解決問題?

為了解決梯度消失或梯度爆炸問題,論文提出通過數據的預處理以及在網絡中使用 BN(Batch Normalization)層來解決。 為了解決深層網絡中的退化問題,可以人為地讓神經網絡某些層跳過下一層神經元的連接,隔層相連,弱化每層之間的強聯系。這種神經網絡被稱為 殘差網絡 (ResNets)。

????????????????????????????????????????實線為測試集錯誤率 虛線為訓練集錯誤率

2.ResNet網絡---殘差結構

ResNet的經典網絡結構有:ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152幾種,其中,ResNet-18和ResNet-34的基本結構相同,屬于相對淺層的網絡,后面3種的基本結構不同于ResNet-18和ResNet-34,屬于更深層的網絡。

不論是多少層的ResNet網絡,它們都有以下共同點:

  • 網絡一共包含5個卷積組,每個卷積組中包含1個或多個基本的卷積計算過程(Conv-> BN->ReLU)
  • 每個卷積組中包含1次下采樣操作,使特征圖大小減半,下采樣通過以下兩種方式實現:
    • 最大池化,步長取2,只用于第2個卷積組(Conv2_x)
    • 卷積,步長取2,用于除第2個卷積組之外的4個卷積組
  • 第1個卷積組只包含1次卷積計算操作,5種典型ResNet結構的第1個卷積組完全相同,卷積核均為7x7, 步長為均2
  • 第2-5個卷積組都包含多個相同的殘差單元,在很多代碼實現上,通常把第2-5個卷積組分別叫做Stage1、Stage2、Stage3、Stage4
  • 首先是第一層卷積使用kernel 7?7,步長為2,padding為3。之后進行BN,ReLU和maxpool。這些構成了第一部分卷積模塊conv1。
  • 然后是四個stage,有些代碼中用make_layer()來生成stage,每個stage中有多個模塊,每個模塊叫做building block,resnet18= [2,2,2,2],就有8個building block。注意到他有兩種模塊BasicBlockBottleneck。resnet18和resnet34用的是BasicBlock,resnet50及以上用的是Bottleneck。無論BasicBlock還是Bottleneck模塊,都用到了殘差連接(shortcut connection)方式:

下圖以ResNet18為例介紹一下它的網絡模型

layer1

????????ResNet18 ,使用的是?BasicBlocklayer1,特點是沒有進行降采樣,卷積層的?stride = 1,不會降采樣。在進行?shortcut?連接時,也沒有經過?downsample?層。

layer2,layer3,layer4

而?layer2layer3layer4?的結構圖如下,每個?layer?包含 2 個?BasicBlock,但是第 1 個?BasicBlock?的第 1 個卷積層的?stride = 2,會進行降采樣。在進行?shortcut?連接時,會經過?downsample?層,進行降采樣和降維

????????residual結構使用了一種shortcut的連接方式,也可理解為捷徑。讓特征矩陣隔層相加,注意F(X)和X形狀要相同,所謂相加是特征矩陣相同位置上的數字進行相加。

????????一個殘差塊有2條路徑 F(x)和 x,F(x) 路徑擬合殘差,可稱之為殘差路徑;?路徑為`identity mapping`恒等映射,可稱之為`shortcut`。圖中的⊕為`element-wise addition`,要求參與運算的F(x)??和?x的尺寸要相同。

其中關鍵技術?Batch Normalization是對每一個卷積后進行標準化

????????Batch Normalization目的:使所有的feature map滿足均值為0,方差為1的分布規律

三、代碼分析

1. 導入必要的庫

import torch
from torch.utils.data import DataLoader,Dataset  # 數據加載相關
from PIL import Image  # 圖像處理
from torchvision import transforms  # 數據預處理
import numpy as np
from torch import nn  # 神經網絡模塊
import torchvision.models as models  # 預訓練模型

2. 模型準備(遷移學習)

這部分是遷移學習的重點,

# 加載預訓練的ResNet-18模型
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)# 凍結所有預訓練參數(遷移學習常用策略)
for param in resnet_model.parameters():print(param)  # 打印參數(實際應用中可刪除)param.requires_grad = False  # 凍結參數,不參與訓練# 獲取原模型最后一層的輸入特征數
in_features = resnet_model.fc.in_features  # ResNet18的fc層輸入是512# 替換最后一層全連接層,輸出類別數為20(根據實際任務調整)
resnet_model.fc = nn.Linear(in_features, 20)# 收集需要更新的參數(只有新替換的全連接層參數)
params_to_update = []
for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)

????????這里采用了遷移學習策略:凍結預訓練模型的大部分參數,只訓練最后一層的分類器,這樣可以加快訓練速度并提高效果。

  • models.resnet18():創建 ResNet-18 網絡結構
  • weights=models.ResNet18_Weights.DEFAULT:使用在 ImageNet 數據集上預訓練好的權重初始化模型
  • 遷移學習的關鍵操作:保留預訓練模型學到的特征提取能力
  • requires_grad = False:告訴 PyTorch 不需要計算這些參數的梯度
  • 原 ResNet-18 用于 1000 類分類,這里替換為 20 類分類
  • 只訓練新替換的全連接層參數,大大減少計算量

3. 數據預處理

data_transforms = {'train': transforms.Compose([  # 訓練集的數據增強transforms.Resize([300, 300]),  # 調整大小transforms.RandomRotation(45),  # 隨機旋轉transforms.CenterCrop(224),  # 中心裁剪transforms.RandomHorizontalFlip(p=0.5),  # 隨機水平翻轉transforms.RandomVerticalFlip(p=0.5),  # 隨機垂直翻轉transforms.ToTensor(),  # 轉為Tensor# 歸一化,使用ImageNet的均值和標準差transforms.Normalize([0.485, 0.456, 0.486], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([  # 驗證集不做數據增強,只做必要處理transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.486], [0.229, 0.224, 0.225])]),
}

4. 自定義數據集類

class food_dataset(Dataset):  # 繼承Dataset類def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []  # 存儲圖像路徑self.labels = []  # 存儲標簽self.transform = transform# 從文件中讀取圖像路徑和標簽with open(file_path, 'r') as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):  # 返回數據集大小return len(self.imgs)def __getitem__(self, idx):  # 獲取單個樣本image = Image.open(self.imgs[idx])  # 打開圖像if self.transform:  # 應用預處理image = self.transform(image)# 處理標簽,轉為Tensorlabel = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

5. 數據加載器

# 創建訓練集和測試集
train_data = food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='test.txt', transform=data_transforms['train'])  # 注意這里可能應該用'valid'# 創建數據加載器,用于批量加載數據
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)

6. 設備配置

# 自動選擇可用的計算設備(GPU優先)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")# 將模型移動到選定的設備
model = resnet_model.to(device)

7. 訓練函數

def train(dataloader, model, loss_fn, optimizer):model.train()  # 切換到訓練模式batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)  # 將數據移動到設備# 前向傳播pred = model.forward(X)loss = loss_fn(pred, y)  # 計算損失# 反向傳播和參數更新optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向傳播計算梯度optimizer.step()  # 更新參數# 打印訓練信息loss = loss.item()if batch_size_num % 64 == 0:print(f"loss: {loss:>7f} [number: {batch_size_num}]")batch_size_num += 1

8. 測試函數

best_acc = 0  # 記錄最佳準確率def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 切換到評估模式test_loss, correct = 0, 0with torch.no_grad():  # 關閉梯度計算,節省內存for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()# 計算正確預測的數量correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 計算平均損失和準確率test_loss /= num_batchescorrect /= sizeprint(f"Test result:\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")# 保存最佳模型global best_accif correct > best_acc:best_acc = correcttorch.save(model, 'best3.pt')  # 保存整個模型

9. 訓練配置和執行

# 定義損失函數和優化器
loss_fn = nn.CrossEntropyLoss()  # 交叉熵損失,適用于分類任務
optimizer = torch.optim.Adam(params_to_update, lr=0.001)  # Adam優化器# 學習率調度器,每10個epoch學習率乘以0.5
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)# 訓練輪次
epochs = 20
acc_s = []
loss_s = []# 開始訓練
for t in range(epochs):print(f"Epoch {t+1}\n-----------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)scheduler.step()  # 更新學習率
print("Done!")
print(f"最佳的結果:\n Accuracy: {(100*best_acc):>0.1f}%")

整體流程總結

  1. 加載預訓練的 ResNet-18 模型并修改最后一層以適應新任務
  2. 定義數據預處理和增強方法
  3. 創建自定義數據集類來讀取圖像和標簽
  4. 設置訓練設備(GPU 或 CPU)
  5. 定義訓練和測試函數
  6. 配置優化器、損失函數和學習率調度器
  7. 執行多輪訓練,每輪結束后在測試集上評估并保存最佳模型

最后我們都結果可以達到百分之90左右,效果得到很大的提升。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/98137.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/98137.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/98137.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Pinia 兩種寫法全解析:Options Store vs Setup Store(含實踐與場景對比)

目標:把 Pinia 的兩種寫法講透,寫明“怎么寫、怎么用、怎么選、各自優缺點與典型場景”。全文配完整代碼與注意事項,可直接當團隊規范參考。一、背景與準備 適用版本:Vue 3 Pinia 2.x安裝與初始化: # 安裝 npm i pini…

setup函數相關【3】

目錄1.setup函數:1.概述:2.案例分析:2.setup函數的優化:(setup語法糖)優化1:優化2:安裝插件:安裝指令:只對當前項目安裝配置vite.config.ts:代碼編…

如何通過AI進行數據資產梳理

最終產出 數據資產清單 包含所有數據資產的詳細目錄,列出數據集名稱、描述、所有者、格式、存儲位置和元數據。 用途:幫助政府部門清晰了解數據資產分布和狀態。 數據質量報告 數據質量評估結果,記錄準確性、完整性、一致性等問題及改進建議,基于政府認可的數據質量框架(如…

【傳奇開心果系列】Flet框架結合pillow實現的英文文字倒映特效自定義模板特色和實現原理深度解析

Flet框架結合pillow實現的英文文字倒映特效自定義模板特色和實現原理深度解析 一、效果展示截圖 二、使用場景 三、特色說明 四、概括說明 五、依賴文件列表 六、安裝依賴命令 七、 項目結構建議 八、注意事項 九、Flet 文字倒影效果實現原理分析 (一)組件結構與功能 1. 圖像…

2025最新深度學習面試必問100題--理論+框架+原理+實踐 (下篇)

2025最新深度學習面試必問100題–理論框架原理實踐 (下篇) 在上篇中,我們已經深入探討了機器學習基礎、CNN、RNN及其變體,以及模型優化的核心技巧。 在下篇中,我們將把目光投向更遠方,聚焦于當今AI領域最炙手可熱的前沿。我們將深…

原子工程用AC6編譯不過問題

…\Output\atk_h750.axf: Error: L6636E: Pre-processor step failed for ‘…\User\SCRIPT\qspi_code.scf.scf’修改前: #! armcc -E ;#! armclang -E --targetarm-arm-none-eabi -mcpucortex-m7 -xc /* 使用說明 ! armclang -E --targetarm-arm-none-eabi -mcpuco…

Python有哪些經典的常用庫?(第一期)

目錄 1、NumPy (數值計算基礎庫) 核心特點: 應用場景: 代碼示例: 2、Pandas (數據分析處理庫) 應用場景: 代碼示例: 3、Scikit-learn (機器學習庫) 核心特點: 應用場景: 代碼示例&am…

現代 C++ 高性能程序驅動器架構

🧠 現代 C 高性能程序驅動器架構M/PA(多進程)是隔離的“孤島”,M/TA(多線程)是共享的“戰場”,EDSM(事件驅動)是高效的“反應堆”,MDSM(消息驅動&…

投資儲能項目能賺多少錢?小程序幫你測算

為解決電網負荷平衡、提升新能源消納等問題,儲能項目的投資開發越來越多。那么,投資儲能項目到底能賺多少錢?適不適合投資?用“綠蟲零碳助手”3秒鐘精準測算。操作只需四步,簡單易懂:1.快速登錄&#xff1a…

Mac 能夠連Wife,但是不能上網問題解決

請按照以下步驟從最簡單、最可能的原因開始嘗試: 第一步:基礎快速排查 這些步驟能解決大部分臨時性的小故障。 重啟設備:關閉您的 Mac 和路由器,等待一分鐘后再重新打開。這是解決網絡問題最有效的“萬能藥”。檢查其他設備&am…

基于SpringBoot的旅游管理系統的設計與實現(代碼+數據庫+LW)

摘要 本文闡述了一款基于SpringBoot框架的旅游管理系統設計與實現。該系統整合了用戶信息管理、旅游資源展示、訂單處理流程及安全保障機制等核心功能,專為提升旅游行業的服務質量和運營效率而設計。 系統采用前后端分離架構,前端界面設計注重跨設備兼…

Springboot樂家流浪貓管理系統16lxw(程序+源碼+數據庫+調試部署+開發環境)帶論文文檔1萬字以上,文末可獲取,系統界面在最后面。

系統程序文件列表項目功能:領養人,流浪貓,領養申請開題報告內容基于Spring Boot的樂家流浪貓管理系統開題報告一、研究背景與意義隨著城市化進程加速和人口增長,流浪貓問題已成為全球性社會挑戰。據統計,全球每年約有1.5億只無家可歸的寵物&a…

函數定義跳轉之代碼跳轉

相信大家在開發的過程中都有用到函數定義跳轉的功能,在 IDE 中,如果在函數調用的地方停留光標,可能會提示對應的函數定義,在 GitHub 中也是如此,對于一些倉庫來說,我們可以直接查看對應的函數定義了&#x…

探討Xsens在人形機器人研發中的四個核心應用

探索Xsens動作捕捉如何改變人形機器人研發——使機器人能夠從人類運動中學習、更直觀地協作并彌合模擬與現實世界之間的差距。人形機器人技術是當今世界最令人興奮且最復雜的前沿領域之一。研究人員不僅致力于開發能夠像人類一樣行走和行動的機器人,還致力于開發能夠…

C語言高級編程:一文讀懂數據結構的四大邏輯與兩大存儲

各類資料學習下載合集 ??https://pan.quark.cn/s/8c91ccb5a474? 作為一名程序員,我們每天都在與“數據”打交道。但你是否想過,這些數據在計算機中是如何被“整理”和“安放”的?為什么有些操作快如閃電,而有些則慢如蝸牛? 答案就藏在數據結構之中。 如果說算法是…

MySQL問題4

MySQL中varchar和char的區別 在 MySQL 中,VARCHAR 和 CHAR 都是用于存儲字符串類型的字段,但它們在存儲方式、性能、適用場景等方面存在明顯區別:1. 存儲方式類型說明CHAR(n)定長字符串,始終占用固定 n 個字符空間。不足的會自動在…

Web3 出海香港 101 |BuildSpace AMA 第一期活動高亮觀點回顧

香港政府在 2022-2023 年之間已經開始布局 Web3,由香港政府全資擁有的數碼港也進行了持續兩年多的深耕。目前數碼港已有接近 300 家企業入駐于此,包括 Animoca Brands、HashKey Group、CertiK 等行業知名獨角獸公司。此外,如 Cobo、OneKey、D…

LTE CA和NR CA的區別和聯系

LTE CA(Carrier Aggregation)和NR CA(New Radio Carrier Aggregation)都是載波聚合技術,它們的核心目標都是通過組合多個頻段的帶寬來提高數據傳輸速率,增強無線網絡的吞吐量。盡管它們的功能相似&#xff…

VBA 中的 Excel 工作表函數

一、引言 在使用VBA進行Excel自動化處理時,我們經常需要調用Excel內置的工作表函數來完成復雜的計算或數據處理任務。然而,很多VBA初學者并不清楚如何正確地在VBA中調用這些函數,甚至重復造輪子。本文將從基礎到進階,系統介紹如何…

老年公寓管理系統設計與實現(代碼+數據庫+LW)

摘要 隨著老齡化社會的不斷發展,老年人群體的生活質量和管理需求逐漸引起社會的廣泛關注。為了提高老年公寓的管理效率與服務質量,開發了一種基于SpringBoot框架的老年公寓管理系統。該系統充分利用了SpringBoot框架的快速開發優勢,結合現代…