17.使用DenseNet網絡進行Fashion-Mnist分類

17.1 DenseNet網絡結構設計

在這里插入圖片描述

import torch
from torch import nn
from torchsummary import summary
#卷積層
def conv_block(input_channels,num_channels):net=nn.Sequential(nn.BatchNorm2d(input_channels),nn.ReLU(),nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1))return net
#過渡層
def transition_block(inputs_channels,num_channels):net=nn.Sequential(nn.BatchNorm2d(inputs_channels),nn.ReLU(),nn.Conv2d(inputs_channels,num_channels,kernel_size=1),nn.AvgPool2d(kernel_size=2,stride=2))return net
#DenseNetBlock
class DenseBlock(nn.Module):def __init__(self, num_convs,input_channels,num_channels):super(DenseBlock,self).__init__()layer=[]for i in range(num_convs):layer.append(conv_block(num_channels*i+input_channels,num_channels))self.net=nn.Sequential(*layer)def forward(self,X):for blk in self.net:Y=blk(X)X=torch.cat((X,Y),dim=1)return X
b1=nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
num_channels,growth_rate=64,32
num_convs_in_dense_block=[4,4,4,4]
blks=[]
for i,num_convs in enumerate(num_convs_in_dense_block):blks.append(DenseBlock(num_convs,num_channels,growth_rate))# 上一個稠密塊的輸出通道數num_channels+=num_convs*growth_rate# 在稠密塊之間添加一個轉換層,使通道數量減半if i!=len(num_convs_in_dense_block)-1:blks.append(transition_block(num_channels,num_channels//2))num_channels=num_channels//2
model=nn.Sequential(b1,*blks,nn.BatchNorm2d(num_channels),nn.ReLU(),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(num_channels,10))
device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model.to(device)
summary(model,input_size=(1,224,224),batch_size=64)

在這里插入圖片描述

17.2 DenseNet網絡實現Fashion-Mnist分類

################################################################################################################
#DenseNet
################################################################################################################
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)#[bs,1,28,28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):epochs = range(1, len(train_loss_list) + 1)plt.figure(figsize=(4, 3))plt.plot(epochs, train_loss_list, label='Train Loss')plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')plt.xlabel('Epoch')plt.ylabel('Value')plt.title(title)plt.legend()plt.grid(True)plt.tight_layout()plt.show()
def train_model(model,train_data,test_data,num_epochs):train_loss_list = []train_acc_list = []test_acc_list = []for epoch in range(num_epochs):total_loss=0total_acc_sample=0total_samples=0loop=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")for X,y in loop:#X=X.reshape(X.shape[0],-1)#print(X.shape)X=X.to(device)y=y.to(device)y_hat=model(X)loss=CEloss(y_hat,y)optimizer.zero_grad()loss.backward()optimizer.step()#loss累加total_loss+=loss.item()*X.shape[0]y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存樣本數total_samples+=X.shape[0]test_acc_samples=0test_samples=0for X,y in test_data:X=X.to(device)y=y.to(device)#X=X.reshape(X.shape[0],-1)y_hat=model(X)y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存樣本數test_samples+=X.shape[0]avg_train_loss=total_loss/total_samplesavg_train_acc=total_acc_sample/total_samplesavg_test_acc=test_acc_samples/test_samplestrain_loss_list.append(avg_train_loss)train_acc_list.append(avg_train_acc)test_acc_list.append(avg_test_acc)print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")plot_metrics(train_loss_list, train_acc_list, test_acc_list)return model
def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)
################################################################################################################
#DenseNet
################################################################################################################
#卷積層
def conv_block(input_channels,num_channels):net=nn.Sequential(nn.BatchNorm2d(input_channels),nn.ReLU(),nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1))return net
#過渡層
def transition_block(inputs_channels,num_channels):net=nn.Sequential(nn.BatchNorm2d(inputs_channels),nn.ReLU(),nn.Conv2d(inputs_channels,num_channels,kernel_size=1),nn.AvgPool2d(kernel_size=2,stride=2))return net
#DenseNetBlock
class DenseBlock(nn.Module):def __init__(self, num_convs,input_channels,num_channels):super(DenseBlock,self).__init__()layer=[]for i in range(num_convs):layer.append(conv_block(num_channels*i+input_channels,num_channels))self.net=nn.Sequential(*layer)def forward(self,X):for blk in self.net:Y=blk(X)X=torch.cat((X,Y),dim=1)return X
b1=nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
num_channels,growth_rate=64,32
num_convs_in_dense_block=[4,4,4,4]
blks=[]
for i,num_convs in enumerate(num_convs_in_dense_block):blks.append(DenseBlock(num_convs,num_channels,growth_rate))# 上一個稠密塊的輸出通道數num_channels+=num_convs*growth_rate# 在稠密塊之間添加一個轉換層,使通道數量減半if i!=len(num_convs_in_dense_block)-1:blks.append(transition_block(num_channels,num_channels//2))num_channels=num_channels//2
################################################################################################################
transforms=transforms.Compose([transforms.Resize(96),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])#第一個是mean,第二個是std
train_img=torchvision.datasets.FashionMNIST(root="./data",train=True,transform=transforms,download=True)
test_img=torchvision.datasets.FashionMNIST(root="./data",train=False,transform=transforms,download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(b1,*blks,nn.BatchNorm2d(num_channels),nn.ReLU(),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(num_channels,10))
model.to(device)
model.apply(init_weights)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=10)
################################################################################################################

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/88874.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/88874.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/88874.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

網安系列【16】之Weblogic和jboss漏洞

文章目錄一 Weblogic1.1 Weblogic相關漏洞1.2 Weblogic漏洞發現1.3 Weblogic漏洞利用二 Jboss2.1 Jboss漏洞2.2 Jboss識別與漏洞利用一 Weblogic WebLogic 是由 Oracle公司 開發的一款基于Java EE(現稱Jakarta EE)的企業級應用服務器,主要用…

Unity URP + XR 自定義 Skybox 在真機變黑問題全解析與解決方案(支持 Pico、Quest 等一體機)

在使用 Unity 的 URP 渲染管線開發 XR 應用(如 Pico Neo、Pico 4、Quest 2/3 等一體機)時,很多開發者遇到一個奇怪的問題:打包后,Skybox(天空盒)在某些角度下突然變黑,只在轉動頭部后…

Cursor、飛算JavaAI、GitHub Copilot、Gemini CLI 等熱門 AI 開發工具合集

Cursor:代碼編寫的智能伙伴?Cursor 是 Anysphere 公司推出的一款 AI 編程工具,它基于微軟開源代碼編輯器 VS Code 開發,將 AI 技術深度整合到開發人員的工作流程中。Cursor 的功能十分強大,不僅能夠自動用純英文編寫代碼&#xf…

如何安裝歷史版本或指定版本的 git

背景 有的時候,我們需要安裝指定版本的git,或者希望舊一點的,畢竟我就遇到最新的2.50.1在win10安裝后打開就一閃而過,而安裝2.49.1就不會 下載 官網可能比較難找,但是這個github倉庫:https://github.com/gi…

LaCo: Large Language Model Pruning via Layer Collapse

發表:EMNLP_FINDING_2024 機構:Shanghai Jiao Tong University 連接:LaCo: Large Language Model Pruning via Layer Collapse - ACL Anthology 代碼:https://github.com/yangyifei729/LaCo Abstract 基于 Transformer 的大語…

服務器內核級故障排查

目錄 **檢查內核級故障(Oops/Panic)的具體操作步驟****1. 查看完整 `dmesg` 日志(含時間戳)****2. 過濾關鍵錯誤信息****3. 檢查系統日志中的內核消息****4. 分析最近一次啟動的日志****5. 檢查是否有 `vmcore` 轉儲文件****常見內核錯誤示例及含義**補充說明:檢查內核級故…

Flink學習筆記:整體架構

開一個新坑,系統性的學習下 Flink,計劃從整體架構到核心概念再到調優方法,最后是相關源碼的閱讀。 今天就來學習 Flink 整體架構,我們先看官網的架構圖圖中包含三部分,分別是 Client、JobManager 和 TaskManager。其中…

【LeetCode 熱題 100】105. 從前序與中序遍歷序列構造二叉樹——(解法二)O(n)

Problem: 105. 從前序與中序遍歷序列構造二叉樹 給定兩個整數數組 preorder 和 inorder ,其中 preorder 是二叉樹的先序遍歷, inorder 是同一棵樹的中序遍歷,請構造二叉樹并返回其根節點。 【LeetCode 熱題 100】105. 從前序與中序遍歷序列構…

完美卸載 Ubuntu 雙系統:從規劃到實施的完整指南

📖 前言 最近成功完成了一次 Ubuntu 雙系統的完整卸載,從最初的分區刪除到最終解決 GRUB 引導問題,整個過程雖然有些曲折,但最終完美解決。本文將詳細分享整個卸載過程,希望能幫助到有類似需求的朋友。 &#x1f3af…

深入理解oracle ADG和RAC

1. 引言 本節詳細介紹oracle ADG和RAC。當然這里講得的詳細是相對理論的深入,不涉及到實驗,比如ADG和RAC的搭建及調優等。 RAC (Real Application Clusters) 和 ADG (Active Data Guard)是Oracle 的兩大核心高可用和災備技術。它們是 Oracle 數據庫高可用…

網絡安全實踐:從環境搭建到漏洞復現

要求:1.搭建docker2.使用小皮面板搭建pikachu靶場3.使用BP的爆破模塊破解pikachu的登陸密碼步驟4.Kail的msf復現永恒之藍一.搭建docker1. Docker介紹Docker 是容器,可以部分完全封閉。封閉意味:一個物質(放到容器)&…

車載診斷架構 --- 診斷功能開發流程

我是穿拖鞋的漢子,魔都中堅持長期主義的汽車電子工程師。 老規矩,分享一段喜歡的文字,避免自己成為高知識低文化的工程師: 做到欲望極簡,了解自己的真實欲望,不受外在潮流的影響,不盲從,不跟風。把自己的精力全部用在自己。一是去掉多余,凡事找規律,基礎是誠信;二是…

mysql數據庫知識

MySQL數據庫詳解MySQL是目前全球最流行的關系型數據庫管理系統之一,以其開源免費、高效穩定、易于擴展等特點,被廣泛應用于Web開發、企業級應用等場景。本文將從基礎概念、核心特性到實際應用,對MySQL進行全面解析。一、MySQL的基本概念1. 關…

基于springboot的美食文化和旅游推廣系統

博主介紹:java高級開發,從事互聯網行業多年,熟悉各種主流語言,精通java、python、php、爬蟲、web開發,已經做了多年的畢業設計程序開發,開發過上千套畢業設計程序,沒有什么華麗的語言&#xff0…

Rust賦能文心大模型4.5智能開發

文心大模型4.5版本概論 文心大模型4.5是百度推出的最新一代大規模預訓練語言模型,屬于文心大模型(ERNIE)系列。該模型在自然語言處理(NLP)、多模態理解與生成等領域表現出色,廣泛應用于智能搜索、內容創作、對話交互等場景。 核心能力 語言理解與生成 支持復雜語義理解…

前端抓包(不啟動前端項目就能進行后端調試)--whistle

1、安裝 1.1.安裝node.js 1.2.安裝whistle npm install -g whistle2.安裝瀏覽器插件【SwitchyOmega】在谷歌瀏覽器應用商店下載安裝即可配置proxy127.0.0.1:8989是w2 start的端口號啟用代理3.啟動服務(每次抓包都得啟動) w2 start點擊鏈接訪問網頁 http:…

kettle從入門到精通 第102課 ETL之kettle xxl-job調度kettle的兩種方式

之前我們一起學習過xxl-job調度carte,采用的xxl-job執行器方式,不了解的可以查看《kettle從入門到精通 第六十一課 ETL之kettle 任務調度器,輕松使用xxl-job調用kettle中的job和trans 》 今天我們一起來學習下使用xxl-job直接使用http調用…

純前端 JavaScript 實現數據導出到 CSV 格式

日常開發中,數據導出到文件通常有兩種方式: 在后端處理,以文件流或者資源路徑的方式返回;后端返回數據,前端按需處理后再觸發瀏覽器的下載事件,已保存到本地文件。 這里介紹后者的一種零依賴的實現方式。…

香港理工大學實驗室定時預約

香港理工大學實驗室定時預約 文章目錄香港理工大學實驗室定時預約簡介接單價格軟件界面網站預約界面代碼對爬蟲、逆向感興趣的同學可以查看文章,一對一小班教學(系統理論和實戰教程)、提供接單兼職渠道:https://blog.csdn.net/weixin_35770067/article/d…

Spring AI 項目實戰(十七):Spring Boot + AI + 通義千問星辰航空智能機票預訂系統(附完整源碼)

系列文章 序號文章名稱1Spring AI 項目實戰(一):Spring AI 核心模塊入門2Spring AI 項目實戰(二):Spring Boot + AI + DeepSeek 深度實戰(附完整源碼)3Spring AI 項目實戰(三):Spring Boot + AI + DeepSeek 打造智能客服系統(附完整源碼)4