【Res模塊學習】結合CIFAR-100分類任務學習

初次嘗試訓練CIFAR-100:【圖像分類】CIFAR-100圖像分類任務-CSDN博客

1.訓練模型(MyModel.py)

import torch
import torch.nn as nnclass BasicRes(nn.Module):def __init__(self, in_cha, out_cha, stride=1, res=True):super(BasicRes, self).__init__()self.layer01 = nn.Sequential(nn.Conv2d(in_channels=in_cha, out_channels=out_cha, kernel_size=3, stride=stride, padding=1),nn.BatchNorm2d(out_cha),nn.ReLU(),)self.layer02 = nn.Sequential(nn.Conv2d(in_channels=out_cha, out_channels=out_cha, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_cha),)if res:self.res = resif in_cha != out_cha or stride != 1:  # 若x和f(x)維度不匹配:self.shortcut = nn.Sequential(nn.Conv2d(in_channels=in_cha, out_channels=out_cha, kernel_size=1, stride=stride),nn.BatchNorm2d(out_cha),)else:self.shortcut = nn.Sequential()def forward(self, x):residual = xx = self.layer01(x)x = self.layer02(x)if self.res:x += self.shortcut(residual)return x# 2.訓練模型
class cifar100(nn.Module):def __init__(self):super(cifar100, self).__init__()# 初始維度3*32*32self.Stem = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2),  # (32-5+2*2)/1+1=32nn.BatchNorm2d(64),nn.ReLU(),)self.layer01 = BasicRes(in_cha=64, out_cha=64)self.layer02 = BasicRes(in_cha=64, out_cha=64)self.layer11 = BasicRes(in_cha=64, out_cha=128)self.layer12 = BasicRes(in_cha=128, out_cha=128)self.layer21 = BasicRes(in_cha=128, out_cha=256)self.layer22 = BasicRes(in_cha=256, out_cha=256)self.layer31 = BasicRes(in_cha=256, out_cha=512)self.layer32 = BasicRes(in_cha=512, out_cha=512)self.pool_max01 = nn.MaxPool2d(1, 1)self.pool_max02 = nn.MaxPool2d(2)self.pool_avg = nn.AdaptiveAvgPool2d((1, 1))  # b*c*1*1self.fc = nn.Sequential(nn.Dropout(0.4),nn.Linear(512, 256),nn.ReLU(),nn.Linear(256, 100),)def forward(self, x):x = self.Stem(x)x = self.pool_max01(x)x = self.layer01(x)x = self.layer02(x)x = self.pool_max02(x)x = self.layer11(x)x = self.layer12(x)x = self.pool_max02(x)x = self.layer21(x)x = self.layer22(x)x = self.pool_max02(x)x = self.layer31(x)x = self.layer32(x)x = self.pool_max02(x)x = self.pool_avg(x).view(x.size()[0], -1)x = self.fc(x)return x

由于CIFAR-100圖像維度為(3,32,32),適當修改了ResNet-18的設計框架加以應用。

2.正式訓練

import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import time
from MyModel import BasicRes, cifar100total_start = time.time()# 正式訓練函數
def train_val(train_loader, val_loader, device, model, loss, optimizer, epochs, save_path):  # 正式訓練函數model = model.to(device)plt_train_loss = []  # 訓練過程loss值,存儲每輪訓練的均值plt_train_acc = []  # 訓練過程acc值plt_val_loss = []  # 驗證過程plt_val_acc = []max_acc = 0  # 以最大準確率來確定訓練過程的最優模型for epoch in range(epochs):  # 開始訓練train_loss = 0.0train_acc = 0.0val_acc = 0.0val_loss = 0.0start_time = time.time()model.train()for index, (images, labels) in enumerate(train_loader):images, labels = images.to(device), labels.to(device)optimizer.zero_grad()  # 梯度置0pred = model(images)bat_loss = loss(pred, labels)  # CrossEntropyLoss會對輸入進行一次softmaxbat_loss.backward()  # 回傳梯度optimizer.step()  # 更新模型參數train_loss += bat_loss.item()# 注意此時的pred結果為64*10的張量pred = pred.argmax(dim=1)train_acc += (pred == labels).sum().item()print("當前為第{}輪訓練,批次為{}/{},該批次總loss:{} | 正確acc數量:{}".format(epoch+1, index+1, len(train_data)//config["batch_size"],bat_loss.item(), (pred == labels).sum().item()))# 計算當前Epoch的訓練損失和準確率,并存儲到對應列表中:plt_train_loss.append(train_loss / train_loader.dataset.__len__())plt_train_acc.append(train_acc / train_loader.dataset.__len__())model.eval()  # 模型調為驗證模式with torch.no_grad():  # 驗證過程不需要梯度回傳,無需追蹤gradfor index, (images, labels) in enumerate(val_loader):images, labels = images.cuda(), labels.cuda()pred = model(images)bat_loss = loss(pred, labels)  # 算交叉熵lossval_loss += bat_loss.item()pred = pred.argmax(dim=1)val_acc += (pred == labels).sum().item()print("當前為第{}輪驗證,批次為{}/{},該批次總loss:{} | 正確acc數量:{}".format(epoch+1, index+1, len(val_data)//config["batch_size"],bat_loss.item(), (pred == labels).sum().item()))val_acc = val_acc / val_loader.dataset.__len__()if val_acc > max_acc:max_acc = val_acctorch.save(model, save_path)plt_val_loss.append(val_loss / val_loader.dataset.__len__())plt_val_acc.append(val_acc)print('該輪訓練結束,訓練結果如下[%03d/%03d] %2.2fsec(s) TrainAcc:%3.6f TrainLoss:%3.6f | valAcc:%3.6f valLoss:%3.6f \n\n'% (epoch+1, epochs, time.time()-start_time, plt_train_acc[-1], plt_train_loss[-1], plt_val_acc[-1], plt_val_loss[-1]))print(f'訓練結束,最佳模型的準確率為{max_acc}')plt.plot(plt_train_loss)  # 畫圖plt.plot(plt_val_loss)plt.title('loss')plt.legend(['train', 'val'])plt.show()plt.plot(plt_train_acc)plt.plot(plt_val_acc)plt.title('Accuracy')plt.legend(['train', 'val'])# plt.savefig('./acc.png')plt.show()# 1.數據預處理
transform = transforms.Compose([transforms.RandomHorizontalFlip(),  # 以 50% 的概率隨機翻轉輸入的圖像,增強模型的泛化能力transforms.RandomCrop(size=(32, 32), padding=4),  # 隨機裁剪transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 對圖像張量進行歸一化
])  # 數據增強
ori_data = dataset.CIFAR100(root="./Data_CIFAR100",train=True,transform=transform,download=True
)
print(f"各標簽的真實含義:{ori_data.class_to_idx}\n")
# print(len(ori_data))
# # 查看某一樣本數據
# image, label = ori_data[0]
# print(f"Image shape: {image.shape}, Label: {label}")
# image = image.permute(1, 2, 0).numpy()
# plt.imshow(image)
# plt.title(f'Label: {label}')
# plt.show()config = {"train_size_perc": 0.8,"batch_size": 64,"learning_rate": 0.001,"epochs": 50,"save_path": "model_save/Res_cifar100_model.pth"
}# 設置訓練集和驗證集的比例
train_size = int(config["train_size_perc"] * len(ori_data))  # 80%用于訓練
val_size = len(ori_data) - train_size  # 20%用于驗證
train_data, val_data = random_split(ori_data, [train_size, val_size])
# print(len(train_data))
# print(len(val_data))train_loader = DataLoader(dataset=train_data, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=config["batch_size"], shuffle=False)device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{device}\n")
model = cifar100()
# model = torch.load(config["save_path"]).to(device)
print(f"我的模型框架如下:\n{model}")
loss = nn.CrossEntropyLoss()  # 交叉熵損失函數
optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=1e-4)  # L2正則化
# optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])  # 優化器train_val(train_loader, val_loader, device, model, loss, optimizer, config["epochs"], config["save_path"])print(f"\n本次訓練總耗時為:{(time.time()-total_start) / 60 }min")

3.測試文件

import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
from MyModel import BasicRes, cifar100total_start = time.time()
# 測試函數
def test(save_path, test_loader, device, loss):  # 測試函數best_model = torch.load(save_path).to(device)test_loss = 0.0test_acc = 0.0start_time = time.time()with torch.no_grad():for index, (images, labels) in enumerate(test_loader):images, labels = images.cuda(), labels.cuda()pred = best_model(images)bat_loss = loss(pred, labels)  # 算交叉熵losstest_loss += bat_loss.item()pred = pred.argmax(dim=1)test_acc += (pred == labels).sum().item()print("正在最終測試:批次為{}/{},該批次總loss:{} | 正確acc數量:{}".format(index + 1, len(test_data) // config["batch_size"],bat_loss.item(), (pred == labels).sum().item()))print('最終測試結束,測試結果如下:%2.2fsec(s) TestAcc:%.2f%%  TestLoss:%.2f \n\n'% (time.time() - start_time, test_acc/test_loader.dataset.__len__()*100, test_loss/test_loader.dataset.__len__()))# 1.數據預處理
transform = transforms.Compose([transforms.RandomHorizontalFlip(),  # 以 50% 的概率隨機翻轉輸入的圖像,增強模型的泛化能力transforms.RandomCrop(size=(32, 32), padding=4),  # 隨機裁剪transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 對圖像張量進行歸一化
])  # 數據增強
test_data = dataset.CIFAR100(root="./Data_CIFAR100",train=False,transform=transform,download=True
)
# print(len(test_data))  # torch.Size([3, 32, 32])
config = {"batch_size": 64,"save_path": "model_save/Res_cifar100_model.pth"
}
test_loader = DataLoader(dataset=test_data, batch_size=config["batch_size"], shuffle=True)
loss = nn.CrossEntropyLoss()  # 交叉熵損失函數
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{device}\n")test(config["save_path"], test_loader, device, loss)print(f"\n本次訓練總耗時為:{time.time()-total_start}sec(s)")

4.訓練結果

設計learning rate=0.001先訓練了30輪,模型在測試集上的準確率已經來到了62.61%;
?

后續引入學習率衰減策略對同一網絡進行再次訓練,初始lr=0.001,衰減系數0.2,每20輪衰減一次,訓練60輪,結果如下:

最終訓練模型在測試集上的準確率 達到了65.20%

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/79406.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/79406.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/79406.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

愛勝品ICSP YPS-1133DN Plus黑白激光打印機報“自動進紙盒進紙失敗”處理方法之一

故障現象如下圖提示: 用戶的愛勝品ICSP YPS-1133DN Plus黑白激光打印機在工作過程中提示自動進紙盒進紙失敗并且紅色故障燈閃爍; 給出常見故障一般處理建議如下: 當您的愛勝品ICSP YPS-1133DN Plus 黑白激光打印機出現“自動進紙盒進紙失敗”…

Flinkcdc 實現 MySQL 寫入 Doris

Flinkcdc 實現 MySQL 寫入 Doris Flinkcdc 實現 MySQL 寫入 Doris 一、環境配置 Doris:3.0.4 JDK 17 MySQL (業務數據庫):5.7 MySQL(本地數據庫):5.7 Flink:flink-1.19.1 flinkc…

【Linux庖丁解牛】—環境變量!

目錄 1. 環境變量 1.1 概念介紹 1.2 命令行參數 1.3 一個例子,一個環境變量 1.4 認識更多的環境變量 1.5 獲取環境變量的方法 a. 指令操作 b. 代碼操作 1.6 理解環境變量的特性 a.環境變量具有全局特性 b.補充兩個概念(為后面埋一個伏筆) 1. 環境變量 …

LangChain4j +DeepSeek大模型應用開發——7 項目實戰 創建硅谷小鹿

這部分我們實現硅谷小鹿的基本聊天功能,包含聊天記憶、聊天記憶持久化、提示詞 1. 創建硅谷小鹿 創建XiaoLuAgent package com.ai.langchain4j.assistant;import dev.langchain4j.service.*; import dev.langchain4j.service.spring.AiService;import static dev…

普通 html 項目也可以支持 scss_sass

項目結構示例 下載vscode的插件Live Sass Compiler 自動監聽編譯scss 下載插件Live Server 用于 web 服務器,打開 html 文件到瀏覽器,也可以不用這個,自己用 nginx 或者寶塔其他 web 工具 新建一個 index.scss打開,點擊 vscode 底…

網工_IP協議

2025.02.17:小猿網&網工老姜學習筆記 第19節 IP協議 9.1 IP數據包的格式(首部數據部分)9.1.1 IP協議的首部格式(固定部分可變部分) 9.2 IP數據包分片(找題練)9.3 TTL生存時間的應用9.4 常見…

SQL語句練習 自學SQL網 在查詢中使用表達式 統計

目錄 Day 9 在查詢中使用表達式 Day 10 在查詢中進行統計 聚合函數 Day 11 在查詢中進行統計 HAVING關鍵字 Day12 查詢執行順序 Day 9 在查詢中使用表達式 SELECT id , Title , (International_salesDomestic_sales)/1000000 AS International_sales FROM moviesLEFT JOIN …

基于機器學習的輿情分析算法研究

標題:基于機器學習的輿情分析算法研究 內容:1.摘要 隨著互聯網的飛速發展,輿情信息呈現爆炸式增長,如何快速準確地分析輿情成為重要課題。本文旨在研究基于機器學習的輿情分析算法,以提高輿情分析的效率和準確性。方法上,收集了近…

菲索旋轉齒輪法:首次地面光速測量的科學魔術

一、當齒輪邂逅光束:19世紀的光速實驗室 1849年,法國物理學家阿曼德菲索(Armand Fizeau)在巴黎郊外的一座莊園里,用一組旋轉齒輪、一面鏡子和一盞油燈,完成了人類首次地面光速測量。他的實驗測得光速為315…

上位機知識篇---PSRAM和RAM

文章目錄 前言一、RAM(Random Access Memory)1. 核心定義分類:SRAM(靜態RAM)DRAM(動態RAM) 2. 關鍵特性SRAM優點缺點應用 DRAM優點缺點應用 3. 技術演進DDR SDRAMLPDDR(低功耗DRAM&a…

Qt QComboBox 下拉復選多選(multicombobox)

Qt QComboBox 下拉復選多選(multicombobox),備忘,待更多測試 【免費】QtQComboBox下拉復選多選(multicombobox)資源-CSDN文庫

ElasticSearch深入解析(五):如何將一臺電腦上的Elasticsearch服務遷移到另一臺電腦上

文章目錄 0.安裝數據遷移工具1.導出數據2.導出mapping3.導出查詢模板4.拷貝插件5.拷貝配置6.導入到目標電腦上 0.安裝數據遷移工具 Elasticsearch dump是一個用于將Elasticsearch索引數據導出為JSON格式的工具。你可以使用Elasticsearch dump通過命令行或編程接口來導出數據。…

微服務中組件掃描(ComponentScan)的工作原理

微服務中組件掃描(ComponentScan)的工作原理 你的問題涉及到Spring框架中ComponentScan的工作原理以及Maven依賴管理的影響。我來解釋為什么能夠掃描到common模塊的bean而掃描不到其他模塊的bean。 根本原因 關鍵在于**類路徑(Classpath)**的包含情況: Maven依賴…

Python鏡像源配置:

1.用命令進行配置: 1. 使用命令行方式更改鏡像源 可以直接通過 pip config 命令來設置全局或用戶級別的鏡像源地址。例如,使用清華大學開源軟件鏡像站作為新的索引 URL: pip config set global.index-url https://pypi.tuna.tsinghua.edu.…

【SpringBoot】Spring中事務的實現:聲明式事務@Transactional、編程式事務

1. 準備工作 1.1 在MySQL數據庫中創建相應的表 用戶注冊的例子進行演示事務操作,索引需要一個用戶信息表 (1)創建數據庫 -- 創建數據庫 DROP DATABASE IF EXISTS trans_test; CREATE DATABASE trans_test DEFAULT CHARACTER SET utf8mb4;…

javascript 深拷貝和淺拷貝的區別及具體實現方案

一、核心區別 特性淺拷貝深拷貝復制層級僅復制對象的第一層屬性遞歸復制對象的所有層級屬性(包括嵌套對象和數組)引用關系嵌套對象/數組與原對象共享內存(引用拷貝)嵌套對象/數組與原對象完全獨立(值拷貝)…

pytorch對應gpu版本是否可用判斷邏輯

# gpu_is_ok.py import torchdef check_torch_gpu():# 打印PyTorch版本print(f"PyTorch version: {torch.__version__}")# 檢查CUDA是否可用cuda_available torch.cuda.is_available()print(f"CUDA available: {cuda_available}")if cuda_available:# 打印…

國內無法訪問GitHub官網的問題解決

作為一名程序員,在國內訪問GitHub官網經常會遇到打開過慢或者訪問失敗的問題,但通過一些技巧可以改善訪問體驗。GitHub訪問問題的根源在于GitHub官網訪問不穩定的主要原因在于DNS解析過程。當我們直接訪問github.com時,需要通過DNS服務器將域…

使用 MediaPipe 和 OpenCV 快速生成人臉掩膜(Face Mask)

在實際項目中,尤其是涉及人臉識別、換臉、圖像修復等任務時,我們經常需要生成人臉區域的掩膜(mask)。這篇文章分享一個簡單易用的小工具,利用 MediaPipe 和 OpenCV,快速提取人臉輪廓并生成二值掩膜圖像。 …

【動態導通電阻】GaN功率器件中動態導通電阻退化的機制、表征及建模方法

2019年,浙江大學的Shu Yang等人在《IEEE Journal of Emerging and Selected Topics in Power Electronics》上發表了一篇關于GaN(氮化鎵)功率器件動態導通電阻(Dynamic On-Resistance, RON)的研究論文。該文深入探討了GaN功率器件中動態導通電阻退化的機制、表征方法、建模…