14.使用GoogleNet/Inception網絡進行Fashion-Mnist分類

14.1 GoogleNet網絡結構設計

在這里插入圖片描述
在這里插入圖片描述

import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
class Inception(nn.Module):def __init__(self, in_channels,c1,c2,c3,c4,**kwargs):super(Inception,self).__init__(**kwargs)#第一條路線:1*1的卷積層self.p1_1=nn.Conv2d(in_channels,c1,kernel_size=1)#第二條路線:1*1的卷積層+3*3的卷積層self.p2_1=nn.Conv2d(in_channels,c2[0],kernel_size=1)self.p2_2=nn.Conv2d(c2[0],c2[1],kernel_size=3,padding=1)#第三條路線:1*1的卷積層+5*5的卷積層self.p3_1=nn.Conv2d(in_channels,c3[0],kernel_size=1)self.p3_2=nn.Conv2d(c3[0],c3[1],kernel_size=5,padding=2)#第四條路線:3*3Maxpool+1*1 convsself.p4_1=nn.MaxPool2d(kernel_size=3,stride=1,padding=1)self.p4_2=nn.Conv2d(in_channels,c4,kernel_size=1)def forward(self,x):p1=F.relu(self.p1_1(x))#第一層p2=F.relu(self.p2_2(F.relu(self.p2_1(x))))p3=F.relu(self.p3_2(F.relu(self.p3_1(x))))p4=F.relu(self.p4_2(self.p4_1(x)))ft=torch.concat((p1,p2,p3,p4),dim=1)return ft
#組建googlenet
b1=nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b2=nn.Sequential(nn.Conv2d(64,64,kernel_size=1),nn.ReLU(),nn.Conv2d(64,192,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b3=nn.Sequential(Inception(192,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b4=nn.Sequential(Inception(480,192,(96,208),(16,48),64),Inception(512,160,(112,224),(24,64),64),Inception(512,128,(128,256),(24,64),64),Inception(512,112,(144,288),(32,64),64),Inception(528,256,(160,320),(32,128),128),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b5=nn.Sequential(Inception(832,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())
device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(b1,b2,b3,b4,b5,nn.Linear(480,10)).to(device)
summary(model,input_size=(1,224,224),batch_size=1)

在這里插入圖片描述

14.2 GoogleNet網絡實現Fashion-Mnist分類

import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)#[bs,1,28,28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):epochs = range(1, len(train_loss_list) + 1)plt.figure(figsize=(4, 3))plt.plot(epochs, train_loss_list, label='Train Loss')plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')plt.xlabel('Epoch')plt.ylabel('Value')plt.title(title)plt.legend()plt.grid(True)plt.tight_layout()plt.show()
def train_model(model,train_data,test_data,num_epochs):train_loss_list = []train_acc_list = []test_acc_list = []for epoch in range(num_epochs):total_loss=0total_acc_sample=0total_samples=0loop=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")for X,y in loop:#X=X.reshape(X.shape[0],-1)#print(X.shape)X=X.to(device)y=y.to(device)y_hat=model(X)loss=CEloss(y_hat,y)optimizer.zero_grad()loss.backward()optimizer.step()#loss累加total_loss+=loss.item()*X.shape[0]y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存樣本數total_samples+=X.shape[0]test_acc_samples=0test_samples=0for X,y in test_data:X=X.to(device)y=y.to(device)#X=X.reshape(X.shape[0],-1)y_hat=model(X)y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存樣本數test_samples+=X.shape[0]avg_train_loss=total_loss/total_samplesavg_train_acc=total_acc_sample/total_samplesavg_test_acc=test_acc_samples/test_samplestrain_loss_list.append(avg_train_loss)train_acc_list.append(avg_train_acc)test_acc_list.append(avg_test_acc)print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")plot_metrics(train_loss_list, train_acc_list, test_acc_list)return model
def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)
class Inception(nn.Module):def __init__(self, in_channels,c1,c2,c3,c4,**kwargs):super(Inception,self).__init__(**kwargs)#第一條路線:1*1的卷積層self.p1_1=nn.Conv2d(in_channels,c1,kernel_size=1)#第二條路線:1*1的卷積層+3*3的卷積層self.p2_1=nn.Conv2d(in_channels,c2[0],kernel_size=1)self.p2_2=nn.Conv2d(c2[0],c2[1],kernel_size=3,padding=1)#第三條路線:1*1的卷積層+5*5的卷積層self.p3_1=nn.Conv2d(in_channels,c3[0],kernel_size=1)self.p3_2=nn.Conv2d(c3[0],c3[1],kernel_size=5,padding=2)#第四條路線:3*3Maxpool+1*1 convsself.p4_1=nn.MaxPool2d(kernel_size=3,stride=1,padding=1)self.p4_2=nn.Conv2d(in_channels,c4,kernel_size=1)def forward(self,x):p1=F.relu(self.p1_1(x))#第一層p2=F.relu(self.p2_2(F.relu(self.p2_1(x))))p3=F.relu(self.p3_2(F.relu(self.p3_1(x))))p4=F.relu(self.p4_2(self.p4_1(x)))ft=torch.concat((p1,p2,p3,p4),dim=1)return ft
#組建googlenet
b1=nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b2=nn.Sequential(nn.Conv2d(64,64,kernel_size=1),nn.ReLU(),nn.Conv2d(64,192,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b3=nn.Sequential(Inception(192,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b4=nn.Sequential(Inception(480,192,(96,208),(16,48),64),Inception(512,160,(112,224),(24,64),64),Inception(512,128,(128,256),(24,64),64),Inception(512,112,(144,288),(32,64),64),Inception(528,256,(160,320),(32,128),128),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b5=nn.Sequential(Inception(832,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())
device=torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(b1,b2,b3,b4,b5,nn.Linear(480,10)).to(device)
transforms=transforms.Compose([transforms.Resize(96),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])#第一個是mean,第二個是std
train_img=torchvision.datasets.FashionMNIST(root="./data",train=True,transform=transforms,download=True)
test_img=torchvision.datasets.FashionMNIST(root="./data",train=False,transform=transforms,download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
model.apply(init_weights)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=15)
################################################################################################################

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/88804.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/88804.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/88804.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

NE綜合實驗2:RIP 與 OSPF 動態路由精細配置、FTPTELNET 服務搭建及精準訪問限制

NE綜合實驗2:RIP 與 OSPF 動態路由精細配置、FTPTELNET 服務搭建及精準訪問限制 涉及的協議可以看我之前的文章: RIP實驗 OSPF協議:核心概念與配置要點解析 ACL協議:核心概念與配置要點解析 基于OSPF動態路由與ACL訪問控制的網…

Android 插件化實現原理詳解

Android 插件化實現原理詳解 插件化技術是Android開發中一項重要的高級技術,它允許應用動態加載和執行未安裝的APK模塊。以下是插件化技術的核心實現原理和關鍵技術點: 一、插件化核心思想宿主與插件: 宿主(Host):主應用APK&#…

空間智能-李飛飛團隊工作總結(至2025.07)

李飛飛團隊在空間智能(Spatial Intelligence)領域的研究自2024年起取得了一系列突破性進展,其里程碑成果可歸納為以下核心方向: 一、理論框架提出與定義(2024年) 1、空間智能概念系統化 a.定義: 李飛飛首次明確空間智能為“機器在3D空間和時間中感知、推理和行動的能…

【算法深練】BFS:“由近及遠”的遍歷藝術,廣度優先算法題型全解析

前言 寬度優先遍歷BFS與深度優先遍歷DFS有本質上的區別,DFS是一直擴到低之后找返回,而BFS是一層層的擴展就像剝洋蔥皮一樣。 通常BFS是將所有路徑同時進行嘗試,所以BFS找到的第一個滿足條件的位置,一定是路徑最短的位置&#xf…

ZW3D 二次開發-創建球體

使用中望3d用戶函數 cvxPartSphere 創建球體 函數定義: ZW_API_C evxErrors cvxPartSphere(svxSphereData *Sphere, int *idShape); typedef struct svxSphereData {evxBoolType Combine; /**<@brief combination method */svxPoint Center; /**<@brief sphere ce…

藝術總監的構圖“再造術”:用PS生成式AI,重塑照片敘事框架

在視覺敘事中&#xff0c;我們常常面臨一個核心的“對立統一”&#xff1a;一方面是**“被捕捉的瞬間”&#xff08;The Captured Moment&#xff09;&#xff0c;即攝影師在特定時間、特定地點所記錄下的客觀現實&#xff1b;另一方面是“被期望的敘事”**&#xff08;The Des…

ChatGPT無法登陸?分步排查指南與解決方案

ChatGPT作為全球領先的AI對話工具&#xff0c;日均處理超百萬次登錄請求&#xff0c;登陸問題可能導致用戶無法正常使用服務&#xff0c;影響工作效率或學習進度。 無論是顯示「網絡錯誤」「賬號未激活」&#xff0c;還是持續加載無響應&#xff0c;本文將從網絡連接、賬號狀態…

用Joern執行CPGQL找到C語言中不安全函數調用的流程

1. 引入 靜態應用程序安全測試&#xff08;Static application security testing&#xff09;簡稱SAST&#xff0c;是透過審查程式源代碼來識別漏洞&#xff0c;提升軟件安全性的作法。 Joern 是一個強大的開源靜態應用安全測試&#xff08;SAST&#xff09;工具&#xff0c;專…

讀文章 Critiques of World model

論文名稱&#xff1a;對世界模型的批判 作者單位&#xff1a; CMU&#xff0c; UC SD 原文鏈接&#xff1a;https://arxiv.org/pdf/2507.05169 摘要&#xff1a; 世界模型&#xff08;World Model&#xff09;——即真實世界環境的算法替代物&#xff0c;是生物體所體驗并與之…

利用docker部署前后端分離項目

后端部署數據庫:redis部署:拉取鏡像:doker pull redis運行容器:docker run -d -p 6379:6379 --name my_redis redismysql部署:拉取鏡像:docker pull mysql運行容器:我這里3306被占了就用的39001映射docker run -d -p 39001:3306 -v /home/mysql/conf:/etc/mysql/conf.d -v /hom…

YOLOv11調參指南

YOLOv11調參 1. YOLOv11參數體系概述 YOLOv11作為目標檢測領域的前沿算法&#xff0c;其參數體系可分為四大核心模塊&#xff1a; 模型結構參數&#xff1a;決定網絡深度、寬度、特征融合方式訓練參數&#xff1a;控制學習率、優化器、數據增強策略檢測參數&#xff1a;影響預測…

云原生核心技術解析:Docker vs Kubernetes vs Docker Compose

云原生核心技術解析&#xff1a;Docker vs Kubernetes vs Docker Compose &#x1f6a2;???? 一、云原生核心概念 ?? 云原生&#xff08;Cloud Native&#xff09; 是一種基于云計算模型構建和運行應用的方法論&#xff0c;核心目標是通過以下技術實現彈性、可擴展、高可…

keepalive模擬操作部署

目錄 keepalived雙機熱備 一、配置準備 二、配置雙機熱備&#xff08;基于nginx&#xff09; web1端 修改配置文件 配置腳本文件 web2端 修改配置文件 配置腳本文件 模擬檢測 開啟keepalived服務 訪問結果 故障模擬 中止nginx 查看IP 訪問瀏覽器 重啟服務后…

Java 中的 volatile 是什么?

&#x1f449; volatile &#xff1a;不穩定的 英[?v?l?ta?l] 美[?vɑ?l?tl] adj. 不穩定的;<計>易失的;易揮發的&#xff0c;易發散的;爆發性的&#xff0c;爆炸性的;易變的&#xff0c;無定性的&#xff0c;無常性的;短暫的&#xff0c;片刻的;活潑的&#xff…

MongoDB性能優化實戰指南:原理、實踐與案例

MongoDB性能優化實戰指南&#xff1a;原理、實踐與案例 在大規模數據存儲與查詢場景下&#xff0c;MongoDB憑借其靈活的文檔模型和水平擴展能力&#xff0c;成為眾多互聯網及企業級應用的首選。然而&#xff0c;在生產環境中&#xff0c;隨著數據量和并發的增長&#xff0c;如何…

細談kotlin中綴表達式

Kotlin 是一種適應你編程風格的語言&#xff0c;允許你在想什么時候寫代碼就什么時候寫代碼。Kotlin 提供了一些機制&#xff0c;幫助我們編寫易讀易懂的代碼。其中一個非常有趣的機制是 中綴表達式&#xff08;infix notation&#xff09;。它允許我們定義和調用函數時省略點號…

[Nagios Core] CGI接口 | 狀態數據管理.dat | 性能優化

鏈接&#xff1a;https://assets.nagios.com/downloads/nagioscore/docs/nagioscore/4/en/ docs&#xff1a;Nagios Core Nagios Core 是功能強大的基礎設施監控系統&#xff0c;包含 CGI 程序&#xff0c;允許用戶通過 Web 界面查看當前狀態、歷史記錄等。通過以下技術棧實現…

Linux進程優先級機制深度解析:從Nice值到實時調度

前言 在Linux系統中&#xff0c;進程優先級決定了CPU資源的分配順序&#xff0c;直接影響系統性能和關鍵任務的響應速度。無論是優化服務器負載、確保實時任務穩定運行&#xff0c;還是避免低優先級進程拖慢系統&#xff0c;合理調整進程優先級都是系統管理和性能調優的重要技能…

深入淺出Kafka Broker源碼解析(下篇):副本機制與控制器

一、副本機制深度解析 1.1 ISR機制實現 1.1.1 ISR管理核心邏輯 ISR&#xff08;In-Sync Replicas&#xff09;是Kafka保證數據一致性的核心機制&#xff0c;其實現主要分布在ReplicaManager和Partition類中&#xff1a; public class ReplicaManager {// ISR變更集合&#xff0…

Fluent許可文件安裝和配置

在使用Fluent軟件進行流體動力學模擬之前&#xff0c;正確安裝和配置Fluent許可文件是至關重要的一步。本文將為您提供詳細的Fluent許可文件安裝和配置指南&#xff0c;幫助您輕松完成許可文件的安裝和配置&#xff0c;確保Fluent軟件能夠順利運行。 一、Fluent許可文件安裝步驟…