Fashion-MNIST LeNet訓練

前面使用線性神經網絡softmax 和 ?多層感知機進行圖像分類,本次我們使用LeNet 卷積神經網絡進行

訓練,期望能捕捉到圖像中的圖像結構信息,提高識別精度:

import torch
import torchvision
from torchvision import transforms
from torch.utils import data
import time
from torch import nn
from matplotlib import pyplot as plt
from matplotlib_inline import backend_inline
from IPython import displaysize = lambda x, *args, **kwargs: x.numel(*args, **kwargs)
reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs)
argmax = lambda x, *args, **kwargs: x.argmax(*args, **kwargs)
astype = lambda x, *args, **kwargs: x.type(*args, **kwargs)class Timer:"""記錄多次運行時間"""def __init__(self):"""Defined in :numref:`subsec_linear_model`"""self.times = []self.start()def start(self):"""啟動計時器"""self.tik = time.time()def stop(self):"""停止計時器并將時間記錄在列表中"""self.times.append(time.time() - self.tik)return self.times[-1]def avg(self):"""返回平均時間"""return sum(self.times) / len(self.times)def sum(self):"""返回時間總和"""return sum(self.times)def cumsum(self):"""返回累計時間"""return np.array(self.times).cumsum().tolist()def accuracy(y_hat, y):"""計算預測正確的數量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = argmax(y_hat, axis=1)cmp = astype(y_hat, y.dtype) == yreturn float(reduce_sum(astype(cmp, y.dtype)).cpu())def evaluate_accuracy(net, data_iter, device=None):"""計算在指定數據集上模型的精度"""metric = Accumulator(2)  # 正確預測數、預測總數net.eval()with torch.no_grad():for X, y in data_iter:X, y = X.to(device), y.to(device)metric.add(accuracy(net(X), y), size(y))return metric[0] / metric[1]def evaluate_accuracy_gpu(net, data_iter, device=None):"""計算在指定數據集上模型的精度"""if isinstance(net, nn.Module):net.eval()if not device:device = next(iter(net.parameters())).devicemetric = Accumulator(2)  # 正確預測數、預測總數net.eval()with torch.no_grad():for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]
def use_svg_display():"""使用svg格式在Jupyter中顯示繪圖Defined in :numref:`sec_calculus`"""backend_inline.set_matplotlib_formats('svg')def set_figsize(figsize=(3.5, 2.5)):"""設置matplotlib的圖表大小Defined in :numref:`sec_calculus`"""use_svg_display()plt.rcParams['figure.figsize'] = figsizedef set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):"""設置matplotlib的軸Defined in :numref:`sec_calculus`"""axes.set_xlabel(xlabel)axes.set_ylabel(ylabel)axes.set_xscale(xscale)axes.set_yscale(yscale)axes.set_xlim(xlim)axes.set_ylim(ylim)if legend:axes.legend(legend)axes.grid()def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):"""繪制數據點Defined in :numref:`sec_calculus`"""if legend is None:legend = []set_figsize(figsize)axes = axes if axes else plt.gca()# 如果X有一個軸,輸出Truedef has_one_axis(X):return (hasattr(X, "ndim") and X.ndim == 1 or isinstance(X, list)and not hasattr(X[0], "__len__"))if has_one_axis(X):X = [X]if Y is None:X, Y = [[]] * len(X), Xelif has_one_axis(Y):Y = [Y]if len(X) != len(Y):X = X * len(Y)axes.cla()for x, y, fmt in zip(X, Y, fmts):if len(x):axes.plot(x, y, fmt)else:axes.plot(y, fmt)set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)class Animator:"""在動畫中繪制數據"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):"""Defined in :numref:`sec_softmax_scratch`"""# 增量地繪制多條線if legend is None:legend = []use_svg_display()self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函數捕獲參數self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# 向圖表中添加多個數據點if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)class Accumulator:"""在n個變量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]def get_dataloader_workers():return 4def load_data_fashion_mnist(batch_size, resize=None):"""下載Fashion-MNIST數據集,然后將其加載到內存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):def init_weight(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weight)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = Animator(xlabel='epoch',xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test_acc'])timer, num_batches = Timer(), len(train_iter)for epoch in range(num_epochs):metric = Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], accuracy(y_hat, y),X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches //5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches, (train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'epoch {epoch + 1}, train_l={train_l:.5f}, test_acc={test_acc:.5f}')print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')def try_gpu(): if torch.backends.mps.is_available():return torch.device("mps")elif torch.cuda.is_available():return torch.device("cuda")else:return torch.device("cpu")device = try_gpu()net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10)
)lr, num_epochs = 0.9, 10
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)
timer = Timer()
train_ch6(net, train_iter, test_iter, num_epochs, lr, try_gpu())
print(f'train takes {timer.stop():.2f} sec')

結果如下:

epoch 10, train_l=0.49363, test_acc=0.80840
loss 0.494, train acc 0.812, test acc 0.808
30582.1 examples/sec on mps
train takes 65.80 sec

可以看到其準確率并不比線性模型和多層感知機更高。如果想進一步提高準確率,需進一步調整LeNet的參數,如學習率,學習批次,訓練次數等,大家自己嘗試一下。經過測試,學習率越低,似乎效果更差一些。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/85283.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/85283.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/85283.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

EasyRTC嵌入式音視頻通信SDK助力1v1實時音視頻通話全場景應用

一、方案概述? 在數字化通信需求日益增長的今天,EasyRTC作為一款全平臺互通的實時視頻通話方案,實現了設備與平臺間的跨端連接。它支持微信小程序、APP、PC客戶端等多端協同,開發者通過該方案可快速搭建1v1實時音視頻通信系統,適…

查看make命令執行后涉及的預編譯宏定義的值

要查看 make 命令執行后涉及的預編譯宏定義(如 -D 定義的宏)及其值,可以采用以下方法: 1. 查看 Makefile 中的宏定義 直接檢查 Makefile 或相關構建腳本(如 configure、CMakeLists.txt),尋找 -…

【C/C++】面試常考題目

面試中最常考的數據結構與算法題,適合作為刷題的第一階段重點。 ? 分類 & 推薦題目列表(精選 70 道核心題) 一、數組 & 字符串(共 15 題) 題目類型LeetCode編號兩數之和哈希表#1盛最多水的容器雙指針#11三數…

【芯片學習】555

一、引腳作用 二、原理圖 三、等效原理圖 1.比較器 同相輸入端大于反相輸入端,輸出高電平,反之亦然 2.三極管 給它輸入高電平就可以導通 3.模擬電路部分 4.數字電路部分 這部分的核心是RS觸發器,R-reset代表0,set是置位代表1&am…

Linux《文件系統》

在之前的系統IO當中已經了解了“內存”級別的文件操作,了解了文件描述符、重定向、緩沖區等概念,在了解了這些的知識之后還封裝出了我們自己的libc庫。接下來在本篇當中將會將視角從內存轉向磁盤,研究文件在內存當中是如何進行存儲的&#xf…

Java-代碼段-http接口調用自身服務中的其他http接口(mock)-并建立socket連接發送和接收報文實例

最新版本更新 https://code.jiangjiesheng.cn/article/367?fromcsdn 推薦 《高并發 & 微服務 & 性能調優實戰案例100講 源碼下載》 1. controller入口 ApiOperation("模擬平臺端現場機socket交互過程,需要Authorization")PostMapping(path "/testS…

基于遞歸思想的系統架構圖自動化生成實踐

文章目錄 一、核心思想解析二、關鍵技術實現1. 動態布局算法2. 樣式規范集成3. MCP服務封裝三、典型應用場景四、最佳實踐建議五、擴展方向一、核心思想解析 本系統通過遞歸算法實現了Markdown層級結構到PPTX架構圖的自動轉換,其核心設計思想包含兩個維度: 數據結構遞歸:將…

Python包管理器 uv替代conda?

有人問:python的包管理器uv可以替代conda嗎? 搞數據和算法的把conda當寶貝,其他的場景能替代。 Python的包管理器有很多,pip是原配,uv是后起之秀,conda則主打數據科學。 uv替代pip似乎只是時間問題了,它…

使用pnpm、vite搭建Phaserjs的開發環境

首先,確保你已經安裝了 Node.js 和 npm。然后按照以下步驟操作: 一、使用pnpm初始化一個新的 Vite 項目 pnpm create vite 輸入名字 選擇模板,這里我選擇Vanilla,也可以選擇其他的比如vue 選擇語言 項目新建完成 二、安裝相關依賴 進入項…

JS逆向案例—喜馬拉雅xm-sign詳情頁爬取

JS逆向案例——喜馬拉雅xm-sign詳情頁爬取 聲明網站流程分析總結 聲明 本文章中所有內容僅供學習交流,抓包內容、敏感網址、數據接口均已做脫敏處理,嚴禁用于商業用途和非法用途,否則由此產生的一切后果均與作者無關,若有侵權&am…

姜老師的MBTI課程:MBTI是可以轉變的

我們先來看內向和外向這條軸,I和E內向和外向受先天遺傳因素的影響還是比較大的,因為它事關到了你的硬件,也就是大腦的模型。但是我們在大五人格的排雷避坑和這套課程里面都強調了一個觀點,內向和外向各有優勢,也各有不…

進程同步:生產者-消費者 題目

正確答案: 問題類型: 經典生產者 - 消費者問題 同時涉及同步和互斥。 同步:生產者與消費者通過信號量協調生產 / 消費節奏(如緩沖區滿時生產者等待,空時消費者等待)。互斥:對共享緩沖區的訪問需…

吳恩達MCP課程(1):chat_bot

原課程代碼是用Anthropic寫的,下面代碼是用OpenAI改寫的,模型則用阿里巴巴的模型做測試 .env 文件為: OPENAI_API_KEYsk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx OPENAI_API_BASEhttps://dashscope.aliyuncs.com/compatible-mode…

Netty 實戰篇:手寫一個輕量級 RPC 框架原型

本文將基于前文實現的編解碼與心跳機制,構建一個簡單的 RPC 框架,包括請求封裝、響應解析、動態代理調用。為打造微服務通信基礎打下基礎。 一、什么是 RPC? RPC(Remote Procedure Call,遠程過程調用)允許…

邊緣計算新基建:iVX 輕量生成模塊的 ARM 架構突圍

一、引言 隨著工業 4.0 和物聯網的快速發展,邊緣計算作為連接云端與終端設備的關鍵技術,正成為推動數字化轉型的核心力量。在邊緣計算場景中,設備的實時性、低功耗和離線處理能力至關重要。ARM 架構憑借其低功耗、高能效的特點,成…

C# 基于 Windows 系統與 Visual Studio 2017 的 Messenger 消息傳遞機制詳解:發布-訂閱模式實現

🧑 博主簡介:CSDN博客專家、CSDN平臺優質創作者,高級開發工程師,數學專業,10年以上C/C, C#, Java等多種編程語言開發經驗,擁有高級工程師證書;擅長C/C、C#等開發語言,熟悉Java常用開…

js數據類型有哪些?它們有什么區別?

js數據類型共有8種,分別是undefined,null,boolean,number,string,Object,symbol,bigint symbol和bigint是es6中提出來的數據類型 symbol創建后獨一無二不可變的數據類型,它主要是為了解決出現全局變量沖突的問題 bigint 是一種數字類型的數據,它可以表示任意精度格式的整數,…

Vite打包優化實踐:從分包到性能提升

前言: ??????? 隨著前端應用功能的增加,項目的打包體積也會不斷膨脹,影響加載速度和用戶體驗。本文介紹了幾種常見的打包優化策略,通過Vite和相關插件,幫助減少項目體積、提升性能,優化加載速度。 rollup-plugi…

C++語法系列之模板進階

前言 本次會介紹一下非類型模板參數、模板的特化(特例化)和模板的可變參數&#xff0c;不是最開始學的模板 一、非類型模板參數 字面意思,比如&#xff1a; template<size_t N 10> 或者 template<class T,size_t N 10>比如&#xff1a;靜態棧就可以用到&#…

html5的響應式布局的方法示例詳解

以下是HTML5實現響應式布局的5種核心方法及代碼示例: 1. 媒體查詢(核心方案) /* 默認樣式(移動優先) */ .container {padding: 15px; }/* 中等屏幕(平板) */ @media (min-width: 768px) {.container {padding: 30px;max-width: 720px;} }/* 大屏幕(桌面) */ @media …