[pytorch、學習] - 9.2 微調

參考

9.2 微調

在前面得一些章節中,我們介紹了如何在只有6萬張圖像的Fashion-MNIST訓練數據集上訓練模型。我們還描述了學術界當下使用最廣泛規模圖像數據集ImageNet,它有超過1000萬的圖像和1000類的物體。然而,我們平常接觸到數據集的規模通常在這兩者之間。

假設我們想從圖像中識別出不同種類的椅子,然后將購買鏈接推薦給用戶。一種可能的方法是先找出100種常用的椅子,為椅子拍攝1000張不同角度的圖像,然后在收集到的圖像數據集上訓練一個分類模型。這個椅子數據集雖然可能比Fashion-MNIST數據集要龐大,但樣本仍然不及ImageNet數據集中樣本數的十分之一。這可能會導致適用于ImageNet數據集的復雜模型在這個椅子數據集上過擬合。同時,因為數據量有限,但其成本仍熱不可忽略。

另一種解決辦法是應用遷移學習(transfer learning),將從源數據集學到的知識遷移到目標數據集上。例如,雖然ImageNet數據集的圖像大多跟椅子無關,但在該數據集上訓練的模型可以抽取較通用的圖像特征,從而能夠幫助識別邊緣、紋理、形狀和物體組成等。這些類似的特征對于識別椅子也可能同樣有效。

本節我們介紹遷移學習中的一種常用技術: 微調(fine tuning)。如圖9.1所示,微調由以下4步構成。

  1. 在源數據集(如ImageNet數據集)上預訓練一個神經網絡模型,即源模型。
  2. 創建一個新的神經網絡模型,即目標模型。它復制了源模型上除了輸出層外的所有模型設計及其參數。我們假設這些模型參數包含了源數據集上學習到的知識,且這些知識同樣適用于目標數據集。我們還假設源模型的輸出層跟源數據集的標簽緊密相關,因此在目標模型中不予采用。
  3. 為目標模型添加一個輸出大小為目標數據集類別個數的輸出層,并隨機初始化該層的模型參數。
  4. 在目標數據集(如椅子數據集)上訓練目標模型。我們將從頭訓練輸出層,而其余層的參數都是基于源模型的參數微調得到的。

9.2.1 熱狗識別

接下來我們來實踐一個具體的例子: 熱狗識別。我們將基于一個小數據集在ImageNet數據集上訓練好的ResNet模型進行微調。該小數據集含有數千張包含熱狗和不包含熱狗的圖像。我們使用微調得到的模型來識別一張圖像中是否包含熱狗。

首先,導入實驗所需要的包或模塊。torchvision的models包提供了常用的預訓練模型。如果希望獲取更多的預訓練模型,可以使用pretrained-models.pytorch倉庫.

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
import osimport sys
sys.path.append("..")
import d2lzh_pytorch as d2ldevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

9.2.1.1 獲取數據集

我們使用的熱狗數據集是從網上抓取的,它包含1400張含熱狗的正類圖像,和同樣多包含其他食品的負類圖像。各類的1000張圖像被用于訓練,其余則用于測試。

我們首先將壓縮后的數據集下載到路徑data_dir之下,然后在該路徑將下載好的數據集解壓,得到兩個文件夾hotdog/trainhotdog/test。這兩個文件夾下面均有hotdognot-hotdog兩個類別文件夾,每個類別文件夾里面是圖像文件。

data_dir = "C:/Users/1/Datasets"
os.listdir(os.path.join(data_dir, 'hotdog'))

我們創建兩個ImageFolder實例來分別讀取訓練數據集和測試數據集中的所有圖像文件

train_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/train'))
test_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/test'))

下面畫出前8張正類圖像和最后8張負類圖像。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i- 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)

在這里插入圖片描述
在訓練時,我們先從圖像中裁剪隨機大小和隨機寬高比的一塊隨機區域,然后將該區域縮放為高和寬均為224像素的輸入。測試時,我們將圖像的高和寬均縮放為256像素,然后從中裁剪出高和寬均為224像素的中心區域作為輸入。此外,我們對RGB(紅、綠、藍)三個顏色通道的數值做標準化:每個數值減去通道所有數值的平均值,再除以該通道所有數值的標準差作為輸出。

注: 使用pretrained-models倉庫時,一定要對圖像進行相應的預處理

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406 ], std = [0.229, 0.224, 0.225])
train_augs = transforms.Compose([transforms.RandomResizedCrop(size= 224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize
])
test_augs = transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),normalize
])

9.2.1.2 定義和初始化模型

我們使用在ImageNet數據集上預訓練的ResNet-18作為源模型。這里指定pretrained=True來自動下載并記載預訓練的模型參數。在第一次使用時需聯網下載模型參數

pretrained_net  = models.resnet18(pretrained=True)

打印源模型的成員變量fc。作為一個全連接層,它將ResNet最終的全局平均池化層輸出變成ImageNet數據集上1000類的輸出

print(pretrained_net.fc)

在這里插入圖片描述
可見此時pretrained_net最后的輸出個數等于目標數據集的類別數1000。所以我們應該將最后的fc修改成我們需要輸出類別數:

pretrained_net.fc = nn.Linear(512, 2)

此時,pretrained_netfc層就隨機初始化了,但是其他層依然保存著預訓練得到的參數。由于是在很大的ImageNet數據集上預訓練的,所以參數已經足夠好,因此一般只需使用較小的學習率來微調這些參數,而fc中的隨機參數一般需要更大的學習率從頭訓練。PyTorch可以方便的對模型的不同部分設置不同的學習參數,我們在下面代碼中將fc的學習率設置為已經預訓練過的部分的10倍

output_params = list(map(id, pretrained_net.fc.parameters()))
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())lr = 0.01
optimizer = optim.SGD([{'params': feature_params},{'params': pretrained_net.fc.parameters(), 'lr': lr * 10}],lr = lr, weight_decay=0.001)

9.2.1.3 微調模型

我們先定義一個使用微調的訓練函數train_fine_tuning以便多次調用。

def train_fine_tuning(net, optimizer, batch_size = 128, num_epochs = 15):train_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/train'), transform = train_augs), batch_size, shuffle=True)test_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/test'), transform=test_augs), batch_size)loss = torch.nn.CrossEntropyLoss()d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)

根據前面的設置,我們將以10倍的學習率從頭訓練目標模型的輸出層參數。

train_fine_tuning(pretrained_net, optimizer)

在這里插入圖片描述
作為對比,我們定義一個相同的模型,但將它的所有模型參數都初始化為隨機值。由于整個模型都需要從頭訓練,我們可以使用較大的學習率。

scratch_net = models.resnet18(pretrained=False, num_classes=2)
lr = 0.1
optimizer  = optim.SGD(scratch_net.parameters(), lr = lr, weight_decay = 0.001)
train_fine_tuning(scratch_net, optimizer)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/250110.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/250110.shtml
英文地址,請注明出處:http://en.pswp.cn/news/250110.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Springboot默認加載application.yml原理

Springboot默認加載application.yml原理以及擴展 SpringApplication.run(…)默認會加載classpath下的application.yml或application.properties配置文件。公司要求搭建的框架默認加載一套默認的配置文件demo.properties,讓開發人員實現“零”配置開發,但…

java 集合(Set接口)

Set接口:無序集合,不允許有重復值,允許有null值 存入與取出的順序有可能不一致 HashSet:具有set集合的基本特性,不允許重復值,允許null值 底層實現是哈希表結構 初始容量為16 保存自定義對象時,保證數據的唯…

關于mac機抓包的幾點基礎知識

1. 我使用的抓包工具為WireShark,以下操作按我當前的版本(Version 2.6.1)做的,以前的版本或者以后的版本可能有稍微的區別。 2. 將mac設置為熱點:打開系統偏好設置,點擊共享: 然后點擊WIFI選項,設置WIFI名…

SpringBoot啟動如何加載application.yml配置文件

一、前言 在spring時代配置文件的加載都是通過web.xml配置加載的(Servlet3.0之前)&#xff0c;可能配置方式有所不同&#xff0c;但是大多數都是通過指定路徑的文件名的形式去告訴spring該加載哪個文件&#xff1b; <context-param><param-name>contextConfigLocat…

[github] - git使用小結(分支拉取、版本回退)

1. 首次(fork項目之后) $ git clone [master] $ git branch -a $ git checkout -b [自己的分支名] [遠程倉庫的分支名]克隆的是主干網絡 2. 再次拉取代碼 $ git pull [master下選擇分支名] [分支名] $ git push origin HEAD:[分支名]拉取首先得進入主倉(不是自己的遠程倉)然后…

MYSQL 查看最大連接數和修改最大連接數

MySQL查看最大連接數和修改最大連接數 1、查看最大連接數show variables like %max_connections%;2、修改最大連接數set GLOBAL max_connections 200; 以下的文章主要是向大家介紹的是MySQL最大連接數的修改&#xff0c;我們大家都知道MySQL最大連接數的默認值是100, 這個數值…

阿里云服務器端口開放對外訪問權限

登陸阿里云管理控制臺 點擊自己的實例 點擊安全組配置 點擊配置規則 點擊添加安全組規則 配置出入放心&#xff0c;和開放的端口號&#xff0c;以及那些網段可以訪問&#xff0c;這里設置所有網段都可以訪問 轉自&#xff1a;https://jingyan.baidu.com/article/95c9d20d624d1e…

PageHelper工作原理

數據分頁功能是我們軟件系統中必備的功能&#xff0c;在持久層使用mybatis的情況下&#xff0c;pageHelper來實現后臺分頁則是我們常用的一個選擇&#xff0c;所以本文專門類介紹下。 PageHelper原理 相關依賴 <dependency><groupId>org.mybatis</groupId>&…

10-多寫一個@Autowired導致程序崩了

再是javaweb實驗六中&#xff0c;是讓我們改代碼&#xff0c;讓它跑起來&#xff0c;結果我少注釋了一個&#xff0c;導致一直報錯&#xff0c;檢查許久沒有找到&#xff0c;最后通過代碼替換逐步查找&#xff0c;才發現問題。 轉載于:https://www.cnblogs.com/zhumengdexiaoba…

Java class不分32位和64位

1、32位JDK編譯的java class在32位系統和64位系統下都可以運行&#xff0c;64位系統兼容32位程序&#xff0c;可以理解。2、無論是Linux還是Windows平臺下的JDK編譯的java class在Linux、Windows平臺下通用&#xff0c;Java跨平臺特性。3、64位JDK編譯的java class在32位的系統…

包裝對象

原文地址&#xff1a;https://wangdoc.com/javascript/ 定義 對象是JavaScript語言最主要的數據類型&#xff0c;三種原始類型的值--數值、字符串、布爾值--在一定條件下&#xff0c;也會自動轉為對象&#xff0c;也就是原始類型的包裝對象。所謂包裝對象&#xff0c;就是分別與…

[C++] 轉義序列

參考 C Primer(第5版)P36 名稱轉義序列換行符\n橫向制表符\t報警(響鈴)符\a縱向制表符\v退格符\b雙引號"反斜杠\問號?單引號’回車符\r進紙符\f

vue使用(二)

本節目標&#xff1a; 1.數據路徑的三種方式 2.{{}}和v-html的區別 1.綁定圖片的路徑 方法一&#xff1a;直接寫路徑 <img src"http://pic.baike.soso.com/p/20140109/20140109142534-188809525.jpg"> 方法二&#xff1a;在data中寫路徑&#xff0c;在…

typedef 為類型取別名

#include <stdio.h> int main() {   typedef int myint; // 為int 類型取自己想要的名字   myint a 10;   printf("%d", a);   return 0;} 其他類型的用法也是一樣的 typedef 類型 自己想要取得名字; 轉載于:https://www.cnblogs.com/hello-dummy/p/9…

【C++】如何提高Cache的命中率,示例

參考鏈接 https://stackoverflow.com/questions/16699247/what-is-a-cache-friendly-code 只是堆積&#xff1a;緩存不友好與緩存友好代碼的典型例子是矩陣乘法的“緩存阻塞”。 樸素矩陣乘法看起來像 for(i0;i<N;i) {for(j0;j<N;j) {dest[i][j] 0;for( k;k<N;i)…

springboot---整合redis

pom.xml新增 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></dependency>代碼結構如下 其中redis.yml是連接redis的配置文件&#xff0c;RedisConfig.java是java配置…

[Head First Java] - 簡單的建議程序

參考 - p481、p484 與我對接的業務層使用的是JAVA語言,因此花點時間入門java.下面幾篇博客可能都是關于java的,我覺得在工作中可能會遇到的 簡單的通信 DailyAdviceClient(客戶端程序) import java.io.*; import java.net.*;public class DailyAdviceClient{public void go()…

SQL重復記錄查詢的幾種方法

1 查找表中多余的重復記錄&#xff0c;重復記錄是根據單個字段1 select * from TB_MAT_BasicData1 2 where MATNR in ( select MATNR from TB_MAT_BasicData1 group by MATNR having count(MATNR)>1) 2.表需要刪除重復的記錄&#xff08;重復記錄保留1條&#xff09;&…

Redis 的應用場景

之前講過Redis的介紹&#xff0c;及使用Redis帶來的優勢&#xff0c;這章整理了一下Redis的應用場景&#xff0c;也是非常重要的&#xff0c;學不學得好&#xff0c;能正常落地是關鍵。 下面一一來分析下Redis的應用場景都有哪些。 1、緩存 緩存現在幾乎是所有中大型網站都在…

[Head First Java] - Swing做一個簡單的客戶端

參考 - P487 1. vscode配置java的格式 點擊左下角齒輪 -> 設置 -> 打開任意的setting.json輸入如下代碼 {code-runner.executorMap": {"java": "cd $dir && javac -encoding utf-8 $fileName && java $fileNameWithoutExt"},…