深度學習:遷移學習

遷移學習

標題1.什么是遷移學習

遷移學習(Transfer Learning)是一種機器學習方法,就是把為任務 A 開發 的模型作為初始點,重新使用在為任務 B 開發模型的過程中。遷移學習是通過 從已學習的相關任務中轉移知識來改進學習的新任務,雖然大多數機器學習算 法都是為了解決單個任務而設計的,但是促進遷移學習的算法的開發是機器學 習社區持續關注的話題。 遷移學習對人類來說很常見,例如,我們可能會發現 學習識別蘋果可能有助于識別梨,或者學習彈奏電子琴可能有助于學習鋼琴。
找到目標問題的相似性,遷移學習任務就是從相似性出發,將舊領域 (domain)學習過的模型應用在新領域上

標題2.遷移學習的步驟

1、選擇預訓練的模型和適當的層
通常,我們會選擇在大規模圖像數據集(如ImageNet)上預訓練的模型,如VGG、ResNet等。然后,根據新數據集的特點,選擇需要微調的模型層。對于低級特征的任務(如邊緣檢測),最好使用淺層模型的層,而對于高級特征的任務(如分類),則應選擇更深層次的模型。
2、凍結預訓練模型的參數
保持預訓練模型的權重不變,只訓練新增加的層或者微調一些層,避免因為在數據集中過擬合導致預訓練模型過度擬合。
3、在新數據集上訓練新增加的層
在凍結預訓練模型的參數情況下,訓練新增加的層。這樣,可以使新模型適應新的任務,從而獲得更高的性能。
4、微調預訓練模型的層
在新層上進行訓練后,可以解凍一些已經訓練過的層,并且將它們作為微調的目標。這樣做可以提高模型在新數據集上的性能。
5、評估和測試
在訓練完成之后,使用測試集對模型進行評估。如果模型的性能仍然不夠好,可以嘗試調整超參數或者更改微調層。

標題3.遷移學習實例

該實例使用的模型是ResNet-18殘差神經網絡模型
###1. 導入必要的庫

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

這里導入了后續代碼會用到的庫,具體如下:
torch:PyTorch 深度學習框架的核心庫。
torchvision.models:包含了預訓練的模型,這里會用到 ResNet-18。
torch.nn:用于構建神經網絡的模塊。
torch.utils.data.Dataset 和 torch.utils.data.DataLoader:用于自定義數據集和加載數據。
torchvision.transforms:用于圖像的預處理。
PIL.Image:用于讀取圖像。
numpy:用于數值計算。
###2. 加載預訓練模型并修改全連接層

resnet_model= models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():print(param)param.requires_grad=False
in_features=resnet_model.fc.in_features
resnet_model.fc=nn.Linear(in_features,20)
params_to_update=[]
for param in resnet_model.parameters():if param.requires_grad==True:params_to_update.append(param)

加載預訓練的 ResNet-18 模型。
把模型中所有參數的 requires_grad 設置為 False,也就是凍結這些參數,使其在訓練時不更新。
獲取原模型全連接層的輸入特征數,然后將全連接層替換為一個新的全連接層,輸出維度為 20。
收集所有 requires_grad 為 True 的參數,這些參數會在訓練時更新。
###3. 定義圖像預處理變換

data_transforms = {'train':transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),transforms.CenterCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]),'valid':transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

定義了兩個圖像預處理的組合變換,分別用于訓練集和驗證集。
訓練集的變換包含了數據增強操作,像隨機旋轉、水平翻轉、垂直翻轉等。
驗證集的變換只包含了調整大小、轉換為張量和標準化操作。

4. 自定義數據集類

class food_dataset(Dataset):def __init__(self,file_path,transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return  len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label,dtype=np.int64))return image,label

自定義了一個 food_dataset 類,繼承自 torch.utils.data.Dataset。 init 方法:解析包含圖像路徑和標簽的文本文件,把圖像路徑和標簽分別存到 self.imgs 和 self.labels 中。
len 方法:返回數據集的大小。
getitem 方法:根據索引讀取圖像,對圖像進行預處理,將標簽轉換為張量,然后返回圖像和標簽。

5. 創建數據集和數據加載器

training_data = food_dataset(file_path='./trainbig.txt',transform=data_transforms['train'])
test_data = food_dataset(file_path='./testbig.txt',transform=data_transforms['valid'])
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)

創建訓練集和測試集的數據集對象。
創建訓練集和測試集的數據加載器,設置批量大小為 64,并且打亂數據
###6. 配置訓練設備、損失函數、優化器和學習率調度器

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model=resnet_model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update,lr=0.001)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.5)

選擇合適的訓練設備(GPU 或 CPU)。
把模型移動到所選設備上。
定義交叉熵損失函數。
定義 Adam 優化器,只對之前收集的需要更新的參數進行優化。
定義學習率調度器,每 5 個 epoch 將學習率乘以 0.5。
###7. 定義訓練和測試函數

def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()def test(dataloader, model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches =len(dataloader)model.eval()test_loss,correct =0,0with torch.no_grad():for X, y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X)test_loss+=loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test result:\n Accuracy:{(100 * correct)}%, Avg loss: {test_loss}")acc_s.append(correct)loss_s.append(test_loss)if correct>best_acc:best_acc=correct

train 函數:將模型設置為訓練模式,遍歷訓練數據加載器,計算損失,反向傳播并更新模型參數。
test 函數:將模型設置為評估模式,遍歷測試數據加載器,計算測試集的準確率和平均損失,記錄最佳準確率。
8. 訓練模型并保存

epochs = 20
acc_s = []
loss_s =[]
for t in range(epochs):print(f"Epoch {t + 1}\n-----------")train(train_dataloader, model,loss_fn, optimizer)scheduler.step()test(test_dataloader,model,loss_fn)
print('最優訓練結果為:',best_acc)
torch.save(model.state_dict(), 'food_classification_model.pt')

訓練模型 20 個 epoch。
每個 epoch 結束后,更新學習率并進行測試。
打印最優訓練結果。
保存模型的參數到 food_classification_model.pt 文件中。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/903344.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/903344.shtml
英文地址,請注明出處:http://en.pswp.cn/news/903344.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Rabbitmq下載和安裝(Windows系統,百度網盤)

一.下載安裝Erlang 1.百度云下載 鏈接:https://pan.baidu.com/s/1k_U25KKngEf1iXWD1ANOeg 提取碼:8ilc 2.安裝 傻瓜式安裝 直接下一步 選擇自己要安裝的路徑 3.配置環境變量 增加變量名為:ERLANG_HOME 變量值填寫自己的安裝路徑&#x…

(一)Linux的歷史與環境搭建

【知識預告】 Linux背景介紹Linux操作系統特性Linux的應用場景Linux的發行版本搭建Linux環境 1 Linux背景介紹 1.1 什么是Linux? Linux是一種自由、開源的操作系統。嚴格來說,它是基于類Unix設計思想,旨在為用戶提供穩定、安全、高效的計…

光流法:從傳統方法到深度學習方法

1 光流法簡介 光流(Optical Flow)是指圖像中像素灰度值隨時間的變化而產生的運動場。 簡單來說,它描述了圖像中每個像素點的運動速度和方向。 光流法是一種通過分析圖像序列中像素灰度值來計算光流的方法。對于圖像數據計算出來的光流是一個二…

解決ssh拉取服務器數據,要多次輸入密碼的問題

問題在于,每次循環調用 rsync 都是新開一個連接,所以每次都需要輸入一次密碼。為了只輸入一次密碼,有以下幾種方式可以解決: ? 推薦方案:設置 SSH 免密登錄 最穩最安全的方式是:配置 SSH 免密登錄&#x…

web技術與Nginx網站服務

目錄 一. web基礎 1. 域名概念 2. Hosts 文件 3. DNS 4. 域名注冊 5. 網頁與 HTML 二. 網頁概述 1. HTML 概述 2. HTML 基本標簽 3. 網站和主頁 三. 靜態網頁與動態網頁 1. 靜態網頁 2. 動態網頁 3. 動態網頁語言 四. HTTP 協議 1. HTTP 協議概述 2. HTTP …

信創系統資產清單采集腳本:主機名+IP+MAC 一鍵生成 CSV

原文鏈接:信創系統資產清單采集腳本:主機名IPMAC 一鍵生成 CSV Hello,大家好啊!今天給大家帶來一篇在信創終端操作系統上自動批量采集主機名、IP 和 MAC 并導出為 CSV 表格的實戰文章!本方案使用 sshpass 和 Bash 腳本…

【dify+docker安裝教程】

目錄 一、dify安裝包下載 二、運行環境配置 1、下載docker 2、安裝 2.1 新建文件夾 2.2 安裝 2.3 命令安裝 3.下載完成后需要重啟電腦,注意保存文檔!!注意保存!!注意!!(血的教…

HTML 地理定位(Geolocation)教程

HTML 地理定位(Geolocation)教程 簡介 HTML5 的 Geolocation API 允許網頁應用獲取用戶的地理位置信息。這個功能可用于提供基于位置的服務,如導航、本地搜索、天氣預報等。本教程將詳細介紹如何在網頁中實現地理定位功能。 工作原理 瀏覽器可以通過多種方式確定…

協作開發攻略:Git全面使用指南 — 引言

協作開發攻略:Git全面使用指南 — 引言 Git 是一種分布式版本控制系統,用于跟蹤文件和目錄的變更。它能幫助開發者有效管理代碼版本,支持多人協作開發,方便代碼合并與沖突解決,廣泛應用于軟件開發領域。 文中內容僅限技…

畢業設計-基于預訓練語言模型與深度神經網絡的Web入侵檢測系統

項目技術說明 基于預訓練語言模型與深度神經網絡的Web入侵檢測系統,通過預訓練模型CodeBert分詞,將分詞輸入給BiGRU的深度學習模型訓練。通過sniff函數實時捕獲http流量信息,將流量信息輸入給模型進行檢測,模型可以檢測的類別有S…

[計算機科學#4]:二進制如何塑造數字世界(0和1的力量)

【核知坊】:釋放青春想象,碼動全新視野。 我們希望使用精簡的信息傳達知識的骨架,啟發創造者開啟創造之路!!! 內容摘要: 二進制是計算機世界的基石,數學是世界的…

JUC中各種鎖機制的應用和原理及死鎖問題定位

JUC中各種鎖機制的應用和原理及死鎖問題定位 在互聯網大廠Java求職者的面試中,經常會被問到關于JUC(Java Util Concurrency)中的各種鎖機制及其應用和原理的問題。本文通過一個故事場景來展示這些問題的實際解決方案。 第一輪提問 面試官&…

配置Ubuntu18.04中的Qt Creator為中文(圖文詳解)

配置Qt Creator為中文 1、前言2、先設置Ubuntu系統語言為中文3、配置Qt Creator中文環境2.1 IBus輸入法(方法一)2.2、測試IBus輸入法2.21IBus輸入法終端中測試2.2.2IBus輸入法Qt Creator中測試 2.3、Fcitx輸入法(方法二)2.3.1安裝…

高性能服務器配置經驗指南3——安裝服務器可能遇到的問題及解決方法

文章目錄 1、重裝系統后VScode遠程連接失敗問題2、XRDP連接黑屏問題1. 打開文件2. 添加配置3. 重啟xrdp服務 3、VScode遠程免密連接問題4、Vim編輯文件時出現不同用戶沖突編輯的問題 在完成 服務器基本配置和 深度學習環境準備后,大家應該就可以正常使用服務器了&…

PyQt6基礎_QThread

目錄 前置 代碼: 運行 正常運行 QThread運行報錯 視頻 前置 1 PySide6.QtCore.QThread - Qt for Python QThread官方文檔 2 長時間任務可以放到QThread中執行,避免占用主線程導致界面卡頓無法操作 代碼: import traceback,sys fro…

Spring Boot 應用運行指南

🚀 Spring Boot 應用運行指南 ?? 使用 Maven 🔧 運行命令 $ mvn spring-boot:run? 啟動效果 . ____ _ __ _ _/\\ / ____ __ _ _(_)_ __ __ _ \ \ \ \ ( ( )\___ | _ | _| | _ \/ _ | \ \ \ \\\/ ___)| |_)| | | | | || (_…

jeecgboot 3.8.0 集成knife4j問題一文解決

問題描述: ? 在cloud環境下,若應用系統配置了context-path,則無法通過網關進入后臺接口管理系統 原因分析: ? 查看請求信息發現少拼接了系統的context-path,導致無法正確請求到數據。直接使用正確的地址可以正常通過網關訪問。故此確定為集成knife4j的問題。 解決辦法…

【Flutter】Flutter + Unity 插件結構與通信接口封裝

關聯文檔:【方案分享】Flutter Unity 跨平臺三維渲染架構設計全解:插件封裝、通信機制與熱更新機制—— 支持 Android/iOS/Web 的 3D 內容嵌入與遠程資源管理,助力 XR 項目落地 —— 支持 Android/iOS/Web 的 3D 內容嵌入與遠程資源管理&…

推薦 1 款 9.3k stars 的全景式開源數據分析與可視化工具

Orama 是一個開源的數據分析與可視化項目,由askorama團隊開發和維護。該項目旨在為用戶提供一套強大而易用的工具集,幫助用戶輕松處理和理解大規模數據,通過創建交互式且引人入勝的數據可視化圖表,揭示隱藏在數據背后的深層次洞察…

關于windows API 的鍵鼠可控可測

相關函數解釋 GetAsyncKeyState 是 Windows API 中的一個函數,用于判斷某個虛擬鍵是否被按下。GetAsyncKeyState(VK_ESCAPE) 專門用于檢測 Esc 鍵的狀態。下面為你詳細介紹其用法: 函數原型 cpp SHORT GetAsyncKeyState( int vKey ); 參數 vKey&a…