預訓練--微調

預訓練–微調

一個很簡單的道理,如果我們的模型是再ImageNet下訓練的,那么這個模型一定是會比較復雜的,意思就是這個模型可以識別到很多種類別的即泛化能力很強,但是如果要它精確的識別是否某種類別,它的表現可能就不佳了,因此,我們需要在原來的基礎上再對特定的我們需要識別的類別進行重新訓練,微調原來網絡結構中的參數,此時模型還是可以抽取較通用的圖像特征。
在這里插入圖片描述
參考自https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter09_computer-vision/9.2_fine-tuning
當目標數據集遠小于源數據集時,微調有助于提升模型的泛化能力。

熱狗識別

源數據集是ImageNet,超過1000萬個圖像和1000類物體,熱狗數據集包含1400個正類圖像和其他多種負類圖像
最開始還是導入所需要的庫以及設置cuda

import torch
from torch import nn,optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
import os
import d2lzh_pytorch as d2l
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

下載數據集https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/hotdog.zip
我直接放在了我的默認路徑下,讀數據如下

train_imgs = ImageFolder("hotdog/train")
test_imgs = ImageFolder("hotdog/test")

然后我們觀察一下數據集,可以看到大小,寬高比各不同

# 前八張正類圖像和最后八張負類圖像,可以看到寬高比、大小各不同
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [test_imgs[-1-i][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs,2, 8, scale=2)

在這里插入圖片描述
接下來就是訓練時,我們先從圖像中隨機裁剪一塊區域,然后將該區域縮放成224*224的圖像進行輸入,測試時,我們將圖像的高和寬均縮放為256像素,然后從中裁剪出高、寬均為224的中心區域作為輸入,此外對RGB三通道作標準化,每個數值減去通道的平均值,再除以標準差需要注意的是,在使用預訓練模型時,一定要和預訓練時作同樣的預處理。 如果你使用的是torchvision的models,
那就要求: All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
如果你使用的是pretrained-models.pytorch倉庫,請務必閱讀其README,其中說明了如何預處理。

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_augs = transforms.Compose([#transforms.Resize(size=256),  # 是將最小邊調整到256#transforms.CenterCrop(size=224),transforms.RandomResizedCrop(size=224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize
])test_augs = transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),normalize
])

需要注意的是,首先我有最開始有兩點疑惑

  1. 為什么不能需要從圖像中隨機裁剪一塊區域,然后將該區域縮放成224*224的圖像進行輸入。然后我測試了一下,如果不這樣做的話,那么泛化能力會比較差
  2. 如果非要這么做,那么可不可以直接transforms.Resize(size=224)?不可以的,transforms.Resize(size=224)是把最短的邊變為224,寬高比沒變,那么這樣就會導致圖像的尺寸不一樣,后面自然會報錯,所以需要先transforms.Resize(size=256),然后transforms.CenterCrop(size=224)

之后我們使用在ImageNet上預訓練的ResNet18,pretrained=True,自動下載預訓練參數
不管你是使用的torchvision的models還是pretrained-models.pytorch倉庫,默認都會將預訓練好的模型參數下載到你的home目錄下.torch文件夾。
你可以通過修改環境變量$TORCH_MODEL_ZOO來更改下載目錄

pretrained_net = models.resnet18(pretrained=True)

修改最后一層

pretrained_net.fc = nn.Linear(512, 2)

接下來設置訓練的參數,由于除了最后一層,之前的參數都經過預訓練,所以我們學習率調小一點,最后的fc層是初始化過的,于是我們學習率調大一點

output_params = list(map(id, pretrained_net.fc.parameters()))  # fc層
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())  # 除了fc層
lr = 0.01 # 用來更新特征層
# fc層是lr * 10
optimizer = optim.SGD([{"params":feature_params},{"params":pretrained_net.fc.parameters(), "lr":lr*10}
] ,lr = lr, weight_decay=0.001)

在之后就是訓練了

def train_fine_tuning(net, optimizer, batch_size=64, num_epochs=5):train_iter = DataLoader(ImageFolder("hotdog/train", transform=train_augs), batch_size, shuffle=True)test_iter = DataLoader(ImageFolder("hotdog/test", transform=test_augs), batch_size, shuffle=False)loss = torch.nn.CrossEntropyLoss()d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)
train_fine_tuning(pretrained_net, optimizer)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/207704.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/207704.shtml
英文地址,請注明出處:http://en.pswp.cn/news/207704.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

07-2 Python模塊和命名空間

1. 模塊 概念:其實就是一個Python文件,正常文件有的變量,函數,類,模塊都有 功能:模塊可以被其它程序引入,以使用該模塊中的函數等功能。 示例:test-module.py調用mymodule.py模塊中的now_time…

充電樁IC

充電樁IC 電子元器件百科 文章目錄 充電樁IC前言一、充電樁IC是什么二、充電樁IC的類別三、充電樁IC的應用實例四、充電樁IC的工作原理總結前言 充電樁IC的設計和功能會根據不同的充電協議和市場需求進行調整和定制。目前市場上有許多不同型號和廠家的充電樁IC可供選擇,以滿足…

一篇文章帶你快速入門 Vue 核心語法

一篇文章帶你快速入門 Vue 核心語法 一、為什么要學習Vue 1.前端必備技能 2.崗位多,絕大互聯網公司都在使用Vue 3.提高開發效率 4.高薪必備技能(Vue2Vue3) 二、什么是Vue 概念:Vue (讀音 /vju?/,類似于 view) …

Mysql 日期函數大全

一、時間函數 (一)、獲取當前時間 1、NOW() 獲取當前日期和時間,在程序一開始執行便拿到時間 返回格式 YYYY-MM-DD hh:mm:ss eg: NOW() 得到 2023-12-03 12:20:02 NOW(),SLEEP(2),NOW() 得到 2023-12-03 12:20:02 | 0 | 2023-…

目標檢測——OverFeat算法解讀

論文:OverFeat: Integrated Recognition, Localization and Detection using Convolutional Networks 作者:Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus, Yann LeCun 鏈接:https://arxiv.org/abs/1312.6229 文章…

Go語言-讓我印象深刻的13個特性

我們正在加速進入云原生時代,Go語言作為云原生的一塊基石,確有它的獨到之處。本文介紹Go語言的幾個讓我印象深刻的特性。 1、兼顧開發效率和性能 Go語言兼顧開發效率和性能。可以像Python那樣有很快的開發速度,也可以像C那樣有很快的執行速…

SpringAOP專欄二《原理篇》

上一篇SpringAOP專欄一《使用教程篇》-CSDN博客介紹了SpringAop如何使用,這一篇文章就會介紹Spring AOP 的底層實現原理,并通過源代碼解析來詳細闡述其實現過程。 前言 Spring AOP 的實現原理是基于動態代理和字節碼操作的。不了解動態代理和字節碼操作…

【C語言】函數遞歸詳解(一)

目錄 1.什么是遞歸: 1.1遞歸的思想: 1.2遞歸的限制條件: 2.遞歸舉例: 2.1舉例1:求n的階乘: 2.1.1 分析和代碼實現: 2.1.2圖示遞歸過程: 2.2舉例2:順序打印一個整數的…

機器學習---集成學習的初步理解

1. 集成學習 集成學習(ensemble learning)是現在非常火爆的機器學習方法。它本身不是一個單獨的機器學 習算法,而是通過構建并結合多個機器學習器來完成學習任務。也就是我們常說的“博采眾長”。集 成學習可以用于分類問題集成,回歸問題集成&#xff…

多線程并發Ping腳本

1. 前言 最近需要ping地址,還是挺多的,就使用python搞一個ping腳本,記錄一下,以免丟失了。 2. 腳本介紹 首先檢查是否存在True.txt或False.txt文件,并在用戶確認后進行刪除,然后從IP.txt的文件中讀取IP地…

CSS——sticky定位

1. 大白話解釋sticky定位 粘性定位通俗來說,它就是相對定位relative和固定定位fixed的結合體,它的觸發過程分為三個階段 在最近可滾動容器沒有觸發滑動之前,sticky盒子的表現為相對定位relative【第一階段】, 但當最近可滾動容…

【MATLAB】tvfEMD信號分解+FFT+HHT組合算法

有意向獲取代碼,請轉文末觀看代碼獲取方式~也可轉原文鏈接獲取~ 1 基本定義 TVFEMDFFTHHT組合算法是一種結合了總體變分模態分解(TVFEMD)、傅里葉變換(FFT)和希爾伯特-黃變換(HHT)的信號分解方…

vivado時序方法檢查8

TIMING-30 &#xff1a; 生成時鐘所選主源管腳欠佳 生成時鐘 <clock_name> 所選的主源管腳欠佳 &#xff0c; 時序可能處于消極狀態。 描述 雖然 create_generated_clock 命令允許您指定任意參考時鐘 &#xff0c; 但是生成時鐘應引用在其直接扇入中傳輸的時鐘。此…

電子學會C/C++編程等級考試2021年06月(五級)真題解析

C/C++等級考試(1~8級)全部真題?點這里 第1題:數字變換 給定一個包含5個數字(0-9)的字符串,例如 “02943”,請將“12345”變換到它。 你可以采取3種操作進行變換 1. 交換相鄰的兩個數字 2. 將一個數字加1。如果加1后大于9,則變為0 3. 將一個數字加倍。如果加倍后大于…

JS--異步的日常用法

目錄 JS 異步編程并發&#xff08;concurrency&#xff09;和并行&#xff08;parallelism&#xff09;區別回調函數&#xff08;Callback&#xff09;GeneratorPromiseasync 及 await常用定時器函數 JS 異步編程 并發&#xff08;concurrency&#xff09;和并行&#xff08;p…

Python中一些有趣的例題

下面會寫一些基礎的例題&#xff0c;有興趣的自己也可以練練手&#xff01; 1.假設手機短信收到的數字驗證碼為“278902”&#xff0c;編寫一個程序&#xff0c;讓用戶輸入數字驗證碼&#xff0c;如果數字驗證碼輸入正確&#xff0c;提示“支付成功”&#xff1b;否則提示“數…

Python configparser 模塊:優雅處理配置文件的得力工具

更多資料獲取 &#x1f4da; 個人網站&#xff1a;ipengtao.com 配置文件在軟件開發中扮演著重要的角色&#xff0c;而Python中的 configparser 模塊提供了一種優雅而靈活的方式來處理各種配置需求。本文將深入介紹 configparser 模塊的各個方面&#xff0c;通過豐富的示例代碼…

嵌入式雜記 - MDK的Code, RO-data , RW-data, ZI-data意思

嵌入式雜記 - Keil的Code, RO-data , RW-data, ZI-data意思 MDK中的數據分類MCU中的內部存儲分布MDK中數據類型存儲Code代碼段例子 RO-data 只讀數據段例子 RW-data 可讀寫數據段例子 ZI-data 清零數據段例子 在嵌入式開發中&#xff0c;我們經常都會使用一些IDE&#xff0c;例…

Hadoop學習筆記(HDP)-Part.17 安裝Spark2

目錄 Part.01 關于HDP Part.02 核心組件原理 Part.03 資源規劃 Part.04 基礎環境配置 Part.05 Yum源配置 Part.06 安裝OracleJDK Part.07 安裝MySQL Part.08 部署Ambari集群 Part.09 安裝OpenLDAP Part.10 創建集群 Part.11 安裝Kerberos Part.12 安裝HDFS Part.13 安裝Ranger …

Web前端 ---- 【Vue】Vuex的使用(輔助函數、模塊化開發)

目錄 前言 Vuex是什么 Vuex的配置 安裝vuex 配置vuex文件 Vuex核心對象 actions mutations getters state Vuex在vue中的使用 輔助函數 Vuex模塊化開發 前言 本文介紹一種新的用于組件傳值的插件 —— vuex Vuex是什么 Vuex 是一個專為 Vue.js 應用程序開發的狀態…