# 利用遷移學習優化食物分類模型:基于ResNet18的實踐

利用遷移學習優化食物分類模型:基于ResNet18的實踐

在深度學習的眾多應用中,圖像分類一直是一個熱門且具有挑戰性的領域。隨著研究的深入,我們發現利用預訓練模型進行遷移學習是一種非常有效的策略,可以顯著提高模型的性能,尤其是在數據量有限的情況下。在這篇文章中,我們將探討如何將ResNet18模型遷移到食物分類項目中,并通過一系列技術優化模型性能。

一、遷移學習的背景

遷移學習是一種機器學習技術,它允許模型在一個任務上訓練獲得的知識應用到另一個相關任務上。在圖像分類領域,遷移學習尤其有效,因為不同類別的圖像往往共享一些通用的特征。
在這里插入圖片描述

二、項目概述

本項目的目標是構建一個能夠準確分類食物圖像的模型。我們選擇了ResNet18作為基礎模型,因為它在多個圖像分類任務上都表現出色。通過遷移學習,我們可以利用ResNet18在ImageNet數據集上預訓練的權重,加速模型的收斂并提高分類準確率。
在這里插入圖片描述
在這里插入圖片描述

三、模型遷移

1. 加載預訓練模型

我們首先加載了ResNet18的預訓練模型,并將其所有參數設置為不需要梯度更新,這樣可以防止在訓練過程中改變這些預訓練的權重。

resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():param.requires_grad = False

2. 修改全連接層

由于我們的食物分類任務有20個類別,因此我們需要修改ResNet18的最后一個全連接層,以輸出20個類別的預測。

in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)

3. 選擇性更新參數

在遷移學習中,我們通常只更新模型的最后幾層參數。在我們的案例中,我們只更新了全連接層的參數。

params_to_update= []
for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)

四、數據準備與增強

為了提高模型的泛化能力,我們對訓練數據進行了一系列的增強操作,包括隨機旋轉、裁剪、翻轉和灰度化等。

data_transforms = {'train': transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),transforms.CenterCrop(244),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize([244, 244]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}

五、模型訓練與評估

我們使用交叉熵損失函數和Adam優化器進行模型訓練,并采用學習率調度器來動態調整學習率。

optimizer = torch.optim.Adam(params_to_update, lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

在每個訓練周期結束后,我們在測試集上評估模型的性能,并記錄最佳準確率。

for t in range(epochs):print(f"Epoch {t+1}\n")train(train_dataloader, model, loss_fn, optimizer)scheduler.step()test(test_dataloader, model, loss_fn)
print('最優訓練結果為:', best_acc)

完整代碼

import numpy as np
from PIL import Image
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
import torchvision.models as models
from torch import nn'''將resnet18模型遷移到食物分類項目中'''
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)for param in resnet_model.parameters():print(param)param.requires_grad = Falsein_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)params_to_update= []
for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)# 數據增強
data_transforms = {'train': transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),  # 隨機旋轉,-45到45度transforms.CenterCrop(244),#從中心裁剪240*240transforms.RandomHorizontalFlip(p=0.5),  # 隨機水平翻轉transforms.RandomVerticalFlip(p=0.5),  # 隨機垂直翻轉# transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),  # 轉換為灰度圖transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize([244, 244]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}class food_dataset(Dataset):  #food_dataset是自己創建的類名稱,可以改為你需要的名稱def __init__(self, file_path, transform=None): #類的初始化,解析數據文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f: #是把train.txt文件中圖片的路徑保存在 self.imgs,train.txt文件中標簽保存在 sesamples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path) #圖像的路徑self.labels.append(label) #標簽,還不是tensor# 初始化:把圖片目錄加載到selfdef __len__(self):return len(self.imgs)def __getitem__(self, idx):image=Image.open(self.imgs[idx])if self.transform:image=self.transform(image)label=self.labels[idx]label=torch.from_numpy(np.array(label,dtype=np.int64))return image,label#training_data包含了本次需要訓練的全部數據集
training_data = food_dataset(file_path=r'D:\Users\妄生\PycharmProjects\人工智能\深度學習\trainda.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path=r'D:\Users\妄生\PycharmProjects\人工智能\深度學習\testda.txt', transform=data_transforms['valid'])#training_data需要具備索引的功能,還需要確保數據是tensor
train_dataloader=DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader=DataLoader(test_data,batch_size=64,shuffle=True)device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")model = resnet_model.to(device)  # 將剛剛定義的模型傳入到GPU中def train(dataloader, model, loss_fn, optimizer):  # 傳入參數 打包的數據,卷積模型,損失函數,優化器model.train()  # 表示模型開始訓練batch_size_num = 1for x, y in dataloader:  # 遍歷打包的圖片及其對應的標簽,其中batch為每一個數據的編號x, y = x.to(device), y.to(device)  # 把訓練數據集和標簽傳入cpu或GPUpred = model.forward(x)  # 自動初始化 W權值loss = loss_fn(pred, y)  # 傳入模型訓練結果的預測值和真實值,通過交叉熵損失函數計算損失值L0optimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向傳播計算得到每個參數的梯度optimizer.step()  # 根據梯度更新網絡參數loss = loss.item()  # 獲取損失值if batch_size_num % 100 == 0:print(f"loss: {loss:>7f}[number:{batch_size_num}]")  # 打印損失值,右對齊,長度為7batch_size_num += 1  # 右下方傳入的參數,表示訓練輪數best_acc =0
def test(dataloader, model, loss_fn):  # 定義一個test函數,用于測試模型性能global best_acc  # 定義一個全局變量size = len(dataloader.dataset)  # 返回打包的圖片總數num_batches = len(dataloader)  # 返回打包的包的個數model.eval()  # 表示模型進入測試模式test_loss, correct = 0, 0  # 初始化兩個值,一個用來存放總體損失值,一個存放預測準確的個數with torch.no_grad():  # 一個上下文管理器,關閉梯度計算。當你確認不會調用Tensor.backward()時可以減少for x, y in dataloader:  # 遍歷數據加載器中測試集圖片的圖片及其標簽x, y = x.to(device), y.to(device)  # 傳入GPUpred = model.forward(x)  # 前向傳播,返回預測結果test_loss += loss_fn(pred, y).item()  # 計算所有的損失值的和,item表示將tensor類型值轉化為python標量correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # 判斷預測的值是等于真實值,返回布爾值,將其轉換為0和1,然后求和# a = (pred.argmax(1)== y)  dim=1表示每一行中的最大值對應的索引號,dim=日表示每 b=(pred.argmax(1)==y).type(torch.float)test_loss /= num_batches  # 總體損失值除以數據條數得到平均損失值correct /= size  # 求準確率print(f"Test result:in Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")  # 表示準確率機器對應的損失值acc_s.append(correct)loss_s.append(test_loss)if correct > best_acc:  # 如果新訓練得到的準確率大于前面已經求出來的準確率best_acc = correct  # 將新的準確率傳入值best_accloss_fn = nn.CrossEntropyLoss()  # 創建交叉熵損失雨數對象,因為食物的類別是20
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 創建一個優化器,SGD為隨機梯度下降,Adam為一種自適應優化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.5)#調整學習率函數"訓練模型"
epochs = 50acc_s =[]
loss_s=[]
for t in range(epochs):print(f"Epoch {t+1}\n")train(train_dataloader, model, loss_fn, optimizer)scheduler.step()test(test_dataloader, model, loss_fn)
#在每個epoch的訓練中,使用scheduler.step()語句進行學習率更新
print('最優訓練結果為:',best_acc)

運行結果

在這里插入圖片描述

六、總結

通過本項目,我們成功地將ResNet18模型遷移到了食物分類任務中,并通過遷移學習顯著提高了模型的性能。這種方法不僅減少了訓練時間,還提高了模型的泛化能力。未來,我們可以嘗試更多的遷移學習策略,如使用不同的預訓練模型或調整遷移學習的比例,以進一步提升模型性能。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/76812.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/76812.shtml
英文地址,請注明出處:http://en.pswp.cn/web/76812.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Excel提取圖片并自動上傳到文件服務器(OOS),獲取文件鏈接

Excel提取圖片并自動上傳到接口 在實際項目中,我們可能經常會遇到需要批量從Excel文件(.xlsx)中提取圖片并上傳到特定接口的場景。今天,我就詳細介紹一下如何使用Python實現這一功能,本文會手把手教你搭建一個完整的解…

jmeter利用csv進行參數化和自動斷言

1.測試數據 csv測試數據如下(以注冊接口為例) 2.jemer參數化csv設置 打開 jmeter,添加好線程組、HTTP信息頭管理器、CSV 數據文件設置、注冊請求、響應斷言、查看結果樹 1) CSV 數據文件設置 若 CSV 中數據包含中文,…

騰訊云對象存儲m3u8文件使用騰訊播放器播放

參考騰訊云官方文檔: 播放器 SDK Demo 體驗_騰訊云 重要的一步來了: 登錄騰訊云控制臺,找到對象存儲的存儲桶。 此時,再去刷新剛才創建的播放器html文件,即可看到播放畫面了。

CSS 美化頁面(五)

一、position屬性 屬性值??描述??應用場景?static默認定位方式,元素遵循文檔流正常排列,top/right/bottom/left 屬性無效?。普通文檔流布局,默認布局,無需特殊定位。relative相對定位,相對于元素原本位置進行偏…

Spring MVC 核心注解與文件上傳教程

一、RequestBody 注解詳解 1. 基本使用 作用:從 HTTP 請求體中獲取數據,適用于 POST/PUT 請求。 限制:GET 請求無請求體,不可使用該注解。 示例代碼 Controller RequestMapping("/demo01") public class Demo01Cont…

js原型鏈prototype解釋

function Person(){} var personnew Person() console.log(啊啊,Person instanceof Function);//true console.log(,Person.__proto__Function.prototype);//true console.log(,Person.prototype.__proto__ Object.prototype);//true console.log(,Function.prototype.__prot…

為您的照片提供本地 AI 視覺:使用 Llama Vision 和 ChromaDB 構建 AI 圖像標記器

有沒有花 20 分鐘瀏覽您的文件夾以找到心中的特定圖像或屏幕截圖?您并不孤單。 作為工作中的產品經理,我總是淹沒在競爭對手產品的屏幕截圖、UI 靈感以及白板會議或草圖的照片的海洋中。在我的個人生活中,我總是捕捉我在生活中遇到的事物&am…

Kafka消費者端重平衡流程

重平衡的完整流程需要消費者 端和協調者組件共同參與才能完成。我們先從消費者的視角來審視一下重平衡的流程。在消費者端,重平衡分為兩個步驟:分別是加入組和等待領導者消費者(Leader Consumer)分配方案。這兩個步驟分別對應兩類…

2025年五大ETL數據集成工具推薦

ETL工具作為打通數據孤島的核心引擎,直接影響著企業的決策效率與業務敏捷性。本文精選五款實戰型ETL解決方案,從零門檻的國產免費工具到國際大廠企業級平臺,助您找到最適合的數據集成利器。 一、谷云科技ETLCloud:國產數據集成工…

PageIndex:構建無需切塊向量化的 Agentic RAG

引言 你是否對長篇專業文檔的向量數據庫檢索準確性感到失望?傳統的基于向量的RAG系統依賴于語義相似性而非真正的相關性。但在檢索中,我們真正需要的是相關性——這需要推理能力。當處理需要領域專業知識和多步推理的專業文檔時,相似度搜索常…

ubuntu20.04 遠程桌面Xrdp方式

1,Ubuntu 安裝Xrdp 方法 1.1,安裝xrdp sudo apt install xrdp 1.2,檢查xrdp狀態 sudo systemctl status xrdp 1.3,加入ssl-cert sudo adduser xrdp ssl-cert 1.4,重啟xrdp服務 sudo systemctl restart xrdp 最后…

Java學習手冊:RESTful API 設計原則

一、RESTful API 概述 REST(Representational State Transfer)即表述性狀態轉移,是一種軟件架構風格,用于設計網絡應用程序。RESTful API 是符合 REST 原則的 Web API,通過使用 HTTP 協議和標準方法(GET、…

Spring Boot 核心注解全解:@SpringBootApplication背后的三劍客

大家好呀!👋 今天我們要聊一個超級重要的Spring Boot話題 - 那個神奇的主類注解SpringBootApplication!很多小伙伴可能每天都在用Spring Boot開發項目,但你真的了解這個注解背后的秘密嗎?🤔 別擔心&#x…

weibo_har鴻蒙微博分享,單例二次封裝,鴻蒙微博,微博登錄

weibo_har鴻蒙微博分享,單例二次封裝,鴻蒙微博 HarmonyOS 5.0.3 Beta2 SDK,原樣包含OpenHarmony SDK Ohos_sdk_public 5.0.3.131 (API Version 15 Beta2) 🏆簡介 zyl/weibo_har是微博封裝使用,支持原生core使用 &a…

tomcat集成redis實現共享session

中間件&#xff1a;Tomcat、Redis、Nginx jar包要和tomcat相匹配 jar包&#xff1a;commons-pool2-2.2.jar、jedis-2.5.2.jar、tomcat-redis-session-manage-tomcat7.jar 配置Tomcat /conf/context.xml <?xml version1.0 encodingutf-8?> <!--Licensed to the A…

JavaScript 擴展Array類方法實現數組求和

題目描述&#xff1a;使用原型對象擴展Array類&#xff0c;實現返回數字型數組的和 <script>const arr [1,2,3,4,5,6]Array.prototype.sum function(){return this.reduce((prev,item)>prev item,0)}console.log(arr.sum())</script>求和函數中this 指向調用…

中間件--ClickHouse-11--部署示例(Linux宿主機部署,Docker容器部署)

一、Linux宿主機部署 1、環境準備 操作系統&#xff1a;推薦使用 CentOS 7/8 或 Ubuntu 18.04/20.04。硬件要求&#xff1a; 至少 2 核 CPU 和 4GB 內存。足夠的磁盤空間&#xff08;根據數據量評估&#xff09;。CPU需支持SSE4.2指令集&#xff08;可通過以下命令檢查&#…

鴻蒙NEXT開發權限工具類(申請授權相關)(ArkTs)

import abilityAccessCtrl, { Permissions } from ohos.abilityAccessCtrl; import { bundleManager, common, PermissionRequestResult } from kit.AbilityKit; import { BusinessError } from ohos.base; import { ToastUtil } from ./ToastUtil;/*** 權限工具類&#xff08;…

LVGL學習(二)(lv_label,lv_btn)

3-1_標簽(lv_label) 一、標簽的組成&#xff08;盒子模型&#xff09;?? 標簽由三個核心模塊構成&#xff0c;類似便簽紙的??分層設計??&#xff1a; ??LV_PART_MAIN&#xff08;主體層&#xff09;?? ??功能??&#xff1a;相當于便簽紙的"紙面"&…

深度剖析神經網絡:從基礎原理到面試要點(二)

引言 在人工智能蓬勃發展的今天&#xff0c;神經網絡作為其核心技術之一&#xff0c;廣泛應用于圖像識別、自然語言處理、語音識別等眾多領域。深入理解神經網絡的數學模型和結構&#xff0c;對于掌握人工智能技術至關重要。本文將對神經網絡的關鍵知識點進行詳細解析&#xf…