pytorch | minist手寫數據集

一、神經網絡

神經網絡(Neural Network)是一種受生物神經系統(尤其是大腦神經元連接方式)啟發的機器學習模型,是深度學習的核心基礎。它通過模擬大量 “人工神經元” 的互聯結構,學習數據中的復雜模式和規律,從而實現分類、預測、生成等任務。

本項目采用的是全連接神經網絡(Fully Connected Network)

每一層的神經元與下一層所有神經元連接。

應用:簡單分類 / 回歸任務(如房價預測、鳶尾花分類)。

大概模型如下圖所示:

1.輸入層(Input Layer):?

  1. 接收原始數據(如圖像的像素值、文本的詞向量),不進行計算,僅傳遞數據。

  2. 神經元數量 = 輸入數據的維度(例如:28*28?的彩色圖像輸入層有 28*28=784個神經元)。

2.隱藏層(Hidden Layer)

  1. 位于輸入層和輸出層之間,負責提取數據的特征(如邊緣、紋理、語義等)。
  2. 共有2層,其中第一層設置了250個節點,第二層設置了200層。

3.輸出層(Output Layer)

  1. 輸出模型的最終結果(如分類任務的類別概率、回歸任務的預測值)。
  2. 神經元數量 = 任務目標的維度(10 類分類任務輸出層有 10 個神經元)。

4.激活函數

激活函數是神經網絡能擬合復雜模式的關鍵,其核心作用是引入非線性變換(否則多層網絡等價于單層線性模型)。本項目所使用的激活函數是relu函數:

\text{ReLU}(z) = \max(0, z)

5.計算損失(Loss Calculation)

  1. 損失函數(Loss Function)衡量預測結果與真實標簽的差距(例如:分類任務用交叉熵損失,回歸任務用均方誤差)。
  2. 損失越小,模型預測越準。

在 PyTorch 中,CrossEntropyLoss是用于分類任務的常用損失函數,尤其適用于多類別分類(也可用于二分類)。它的設計非常靈活,內部集成了log_softmaxNLLLoss的功能,簡化了模型輸出與損失計算的流程。

CrossEntropyLoss?=?NLLLoss(log_softmax(logits), targets)

NLLLoss(y, \hat{y})=-\frac{1}{N}ylog \hat{y}

log\, softmax=log\frac{e^{y_{i}}}{\sum {j} e^{y_{j}}}

6.正則化

用于防止模型過擬合,提高泛化能力。

L2 正則化(Ridge Regression):傾向于減小參數值

  • 參數平滑性:L2 正則化會使參數值變小,但不會完全為 0,從而使模型更加平滑。
  • 幾何解釋:L2 的約束區域是一個圓形,與損失函數的等高線相交時,參數更可能落在非零的位置。

在損失函數中添加參數的平方和作為懲罰項:損失函數= 原始損失?+\lambda \sum_{i} w_i^2

7.優化器

Adam(Adaptive Moment Estimation)是深度學習中最流行的優化算法之一,結合了 Adagrad 和 RMSProp 的優點,能夠自適應地調整每個參數的學習率。它在實踐中表現出色,廣泛應用于各種神經網絡訓練任務。

二、代碼

網絡模型

單獨新建一個Python文件用于存儲網絡模型,其中定義了兩個隱藏層,都使用全連接神經網絡,第一個隱藏層有250個節點,第二隱藏層有200個節點,從最后一個隱藏層到輸出層同樣使用全連接神經網絡

self.fc1=nn.Linear(28*28,250)
self.fc2=nn.Linear(250,200)
self.fc3=nn.Linear(200,10)

?前兩層(輸入層->第一個隱藏層->第二個隱藏層)都使用relu激活函數,第二個隱藏層到輸出層暫不使用激活函數,這樣做可以減少loss的計算誤差

x =F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)

?完整代碼如下:

class network1(nn.Module):def __init__(self):super(network1, self).__init__()self.fc1=nn.Linear(28*28,250)self.fc2=nn.Linear(250,200)self.fc3=nn.Linear(200,10)def forward(self, x):x =F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return x

訓練模型

工具模塊

import torchvisionfrom model import *
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

數據處理,將圖片數據轉換成tensor數據類型

trans=torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])

?導入數據,并獲取數據長度

train_dataset=torchvision.datasets.MNIST("./data",train=True,transform=trans,download=True)
test_dataset=torchvision.datasets.MNIST("./data",train=False,transform=trans,download=True)train_data_size=len(train_dataset)
test_data_size=len(test_dataset)

加載數據,并設置批量大小為64,設置?打亂數據順序,因為數據集大小不能被64整除,設置丟棄多余數據

train_dataloader=DataLoader(train_dataset,batch_size=64,shuffle=True,drop_last=True)
test_dataloader=DataLoader(test_dataset,batch_size=64,shuffle=True,drop_last=True)

聲明網絡模型變量,設置損失函數,使用adam優化,設置l2正則

model=network1()loss_fn=nn.CrossEntropyLoss()learning_rate=1e-2
l2_lambda=1e-5
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate,weight_decay=l2_lambda)

?設置訓練過程中可能用到的變量

total_train_step=0
total_test_step=0
epoch=50

??創建SummaryWriter,指定日志保存目錄

writer=SummaryWriter("./mini_logs")

?在迭代器中取出數據和標簽,將數據展成一維的,使其符合網絡模型的輸入,將其放入模型中

    for data in train_dataloader:imgs,targets=dataimgs=torch.reshape(imgs,(64,-1))outputs=model(imgs)

計算損失值

loss = loss_fn(outputs, targets)
total_train_loss+=loss

?計算準確率

accuracy = ((outputs.argmax(1) == targets).sum())
total_train_accuracy += accuracy
total_train_step+=1

更新梯度,進行參數優化

optimizer.zero_grad()
loss.backward()
optimizer.step()

?輸出訓練數據,并將訓練誤差和準確率添加到tensorboard上

# 100輪輸出一次,防止數據太多        if total_train_step % 100 == 0:print('訓練次數:{},Loss:{}'.format(total_train_step, loss.item()))  print("整體訓練集上的Loss為:{}".format(total_train_loss))print("整體訓練集上的正確率為:{}".format(total_train_accuracy / train_data_size))writer.add_scalar("Loss/train_loss", total_train_loss, total_test_step)#使用test_step是為了對齊測試誤差,下同writer.add_scalar("Ac/train_accuracy", total_train_accuracy / train_data_size, total_test_step)

?在test數據集上測試參數效果,過程與訓練過程相似,需要注意的是,在測試過程中可以不計算梯度,加快計算效率

total_test_loss=0total_test_accuracy=0with torch.no_grad():for data in test_dataloader:imgs,targets=dataimgs = torch.reshape(imgs, (64, -1))outputs=model(imgs)loss=loss_fn(outputs,targets)total_test_loss+=lossaccuracy = ((outputs.argmax(1) == targets).sum())total_test_accuracy += accuracyprint("整體測試集上的Loss為:{}".format(total_test_loss))print("整體測試集上的正確率為:{}".format(total_test_accuracy / test_data_size))writer.add_scalar("Loss/test_loss", total_test_loss, total_test_step)writer.add_scalar("Ac/test_accuracy", total_test_accuracy / test_data_size, total_test_step)

最后保存模型并關閉SummaryWriter

torch.save(model,"minst_weight/net_{}.pth".format(i))print("模型已保存")writer.close()

?完整代碼如下:

import torchvisionfrom model import *
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoadertrans=torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])train_dataset=torchvision.datasets.MNIST("./data",train=True,transform=trans,download=True)
test_dataset=torchvision.datasets.MNIST("./data",train=False,transform=trans,download=True)train_data_size=len(train_dataset)
test_data_size=len(test_dataset)train_dataloader=DataLoader(train_dataset,batch_size=64,shuffle=True,drop_last=True)
test_dataloader=DataLoader(test_dataset,batch_size=64,shuffle=True,drop_last=True)model=network1()loss_fn=nn.CrossEntropyLoss()learning_rate=1e-2
l2_lambda=1e-5
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate,weight_decay=l2_lambda)total_train_step=0
total_test_step=0
epoch=50writer=SummaryWriter("./mini_logs")for i in range(epoch):print('--------第{}輪訓練開始--------'.format(i))total_train_loss = 0total_train_accuracy=0for data in train_dataloader:imgs,targets=dataimgs=torch.reshape(imgs,(64,-1))outputs=model(imgs)loss = loss_fn(outputs, targets)total_train_loss+=lossaccuracy = ((outputs.argmax(1) == targets).sum())total_train_accuracy += accuracytotal_train_step+=1optimizer.zero_grad()loss.backward()optimizer.step()if total_train_step % 100 == 0:print('訓練次數:{},Loss:{}'.format(total_train_step, loss.item()))  # 100輪輸出一次,防止數據太多print("整體訓練集上的Loss為:{}".format(total_train_loss))print("整體訓練集上的正確率為:{}".format(total_train_accuracy / train_data_size))writer.add_scalar("Loss/train_loss", total_train_loss, total_test_step)#使用test_step是為了對齊測試誤差,下同writer.add_scalar("Ac/train_accuracy", total_train_accuracy / train_data_size, total_test_step)total_test_loss=0total_test_accuracy=0with torch.no_grad():for data in test_dataloader:imgs,targets=dataimgs = torch.reshape(imgs, (64, -1))outputs=model(imgs)loss=loss_fn(outputs,targets)total_test_loss+=lossaccuracy = ((outputs.argmax(1) == targets).sum())total_test_accuracy += accuracyprint("整體測試集上的Loss為:{}".format(total_test_loss))print("整體測試集上的正確率為:{}".format(total_test_accuracy / test_data_size))writer.add_scalar("Loss/test_loss", total_test_loss, total_test_step)writer.add_scalar("Ac/test_accuracy", total_test_accuracy / test_data_size, total_test_step)total_test_step += 1torch.save(model,"minst_weight/net_{}.pth".format(i))print("模型已保存")writer.close()

訓練結果如下圖所示:

可以看到在test數據集上,該模型的準確率還是非常高的

模型測試

在網頁上隨便截取幾張手寫數字,測試一下其效果,效果可能與test數據集測試的效果相差較大,是因為網上的數據與訓練的數據相差太大,所以這個實驗可能并沒有什么參考價值,可以練練手

數據描述:在網上截取了0-9的圖片,并命名為img_0這種格式

第一步:獲取數據

    image_path= 'imgs/img_{}.png'.format(i)image=Image.open(image_path)

第二步:數據預處理,將彩色圖轉為灰度圖,裁剪其尺寸為28*28,然后將其轉換成tensor數據類型

image=image.convert("L")trans=torchvision.transforms.Compose([torchvision.transforms.Resize((28,28)),torchvision.transforms.ToTensor()])image=trans(image)

第三步:可以查看一下處理完的數據,由于add_image要求數據是三維的,所以需要先處理一下數據

    image=torch.reshape(image,(1,28,-1))writer=SummaryWriter("mini_show")writer.add_image("{}".format(i),image)writer.close()

第四步:加載模型,并輸出訓練結果

    image=torch.reshape(image,(1,-1))model=torch.load('minst_weight/net_40.pth')model.eval()with torch.no_grad():output=model(image)print('{}預測的值是:{}'.format(i,output.argmax(1).item()))

完整代碼如下:?

import torch
import torchvision
from PIL import Image
from torch.utils.tensorboard import SummaryWriterfrom model import *
for i in range(10):image_path= 'imgs/img_{}.png'.format(i)image=Image.open(image_path)image=image.convert("L")trans=torchvision.transforms.Compose([torchvision.transforms.Resize((28,28)),torchvision.transforms.ToTensor()])image=trans(image)image=torch.reshape(image,(1,28,-1))writer=SummaryWriter("mini_show")writer.add_image("{}".format(i),image)writer.close()image=torch.reshape(image,(1,-1))model=torch.load('minst_weight/net_40.pth')model.eval()with torch.no_grad():output=model(image)print('{}預測的值是:{}'.format(i,output.argmax(1).item()))

截取并預處理后的圖片:?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/915378.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/915378.shtml
英文地址,請注明出處:http://en.pswp.cn/news/915378.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

[C/C++安全編程]_[中級]_[如何避免出現野指針]

場景 在Rust里不會出現野指針的情況,那么在C里能避免嗎? 說明 野指針是指指向無效內存地址的指針,訪問它會導致未定義行為,可能引發程序崩潰、數據損壞或安全漏洞。它是 C/C 等手動內存管理語言中的常見錯誤,而 Rust…

機器學習基礎:從數據到智能的入門指南

一、何謂機器學習? 在我們的日常生活中,機器學習的身影無處不在。當你打開購物軟件,它總能精準推薦你可能喜歡的商品;當你解鎖手機,人臉識別瞬間完成;當你使用語音助手,它能準確理解你的指令。這些背后&a…

steam游戲搬磚項目超完整版實操分享

大家好,我是阿陽,今天再次最詳細的給大家綜合全面的分析講解下steam搬磚,可以點擊后面跳轉往期文章了再次解下阿陽網客:關于steam游戲搬磚項目,我想說!最早是21年5月份公開朋友圈,初次接觸是在2…

vue2 面試題及詳細答案150道(21 - 40)

《前后端面試題》專欄集合了前后端各個知識模塊的面試題,包括html,javascript,css,vue,react,java,Openlayers,leaflet,cesium,mapboxGL,threejs&…

原生前端JavaScript/CSS與現代框架(Vue、React)的聯系與區別(詳細版)

原生前端JavaScript/CSS與現代框架(Vue、React)的聯系與區別,以及運行環境和條件 目錄 引言原生前端技術概述 JavaScript基礎CSS基礎 現代框架概述 Vue.jsReact 聯系與相似性主要區別對比運行環境和條件選擇建議總結 引言 在現代Web開發中&…

基于機器視覺的邁克耳孫干涉環自動計數系統設計與實現

基于機器視覺的邁克耳孫干涉環自動計數系統設計與實現 前些天發現了一個巨牛的人工智能學習網站,通俗易懂,風趣幽默,忍不住分享一下給大家。點擊跳轉到網站。 摘要 本文設計并實現了一種基于機器視覺的邁克耳孫干涉環自動計數系統。該系統…

設計模式筆記(1)簡單工廠模式

最近在看程杰的《大話設計模式》,在這里做一點筆記。 書中主要有兩個角色: 小菜:初學者,學生; 大鳥:小菜表哥,大佬。 也按圖中的對話形式 01 簡單工廠模式 要求:使用c、Java、C#或VB…

Vue3 學習教程,從入門到精通,Vue 3 聲明式渲染語法指南(10)

Vue 3 聲明式渲染語法指南 本文將詳細介紹 Vue 3 中的聲明式渲染語法,涵蓋所有核心概念,并通過一個完整的案例代碼進行演示。案例代碼中包含詳細注釋,幫助初學者更好地理解每個部分的功能和用法。 目錄 簡介聲明式渲染基礎 文本插值屬性綁…

React hooks——useReducer

一、簡介useReducer 是 React 提供的一個高級 Hook,用于管理復雜的狀態邏輯。它類似于 Redux 中的 reducer 模式,適合處理包含多個子值、依賴前一個狀態或邏輯復雜的狀態更新場景。與 useState 相比,useReducer 提供更結構化的狀態管理方式。…

SEO中關于關鍵詞分類與布局的方法有那些

前邊我們說到關鍵詞挖掘肯定很重要,但如何把挖掘出來的關鍵詞用好更為重要,下邊我們就來說說很多seo剛入行的朋友比較頭疼的關鍵詞分類問題,為了更直觀的感受搭配了表格,希望可以給大家一些幫助!SEO優化之關鍵詞分類?挖掘出的關鍵…

考研最高效的準備工作是什么

從性價比的角度來說,考研最高效的準備工作是什么呢? 其實就是“卷成績”。 卷學校中各門課程的成績,卷考研必考的數學、英語、政治和專業課的成績。 因為現階段的考研,最看重的仍然是你的成績,特別是初試成績。 有了…

【Linux】基于Ollama和Streamlit快速部署聊天大模型

1.環境準備 1.1 安裝Streamlit 在安裝Streamlit之前,請確保您的系統中已經正確安裝了Python和pip。您可以在終端或命令行中運行以下命令來驗證它們是否已安裝 python --version pip --version一旦您已經準備好環境,現在可以使用pip來安裝Streamlit了。…

Jetpack - ViewModel、LiveData、DataBinding(數據綁定、雙向數據綁定)

一、ViewModel 1、基本介紹 ViewModel 屬于 Android Jetpack 架構組件的一部分,ViewModel 被設計用來存儲和管理與 UI 相關的數據,這些數據在配置更改(例如,屏幕旋轉)時能夠幸存下來,ViewModel 的生命周期與…

Go并發聊天室:從零構建實戰

大家好,今天我將分享一個使用Go語言從零開始構建的控制臺并發聊天室項目。這個項目雖然簡單,但它麻雀雖小五臟俱全,非常適合用來學習和實踐Go語言強大的并發特性,尤其是 goroutine 和 channel 的使用。 一、項目亮點與功能特性 …

瘋狂星期四第13天運營日報

網站運營第13天,點擊觀站: 瘋狂星期四 crazy-thursday.com 全網最全的瘋狂星期四文案網站 運營報告 昨日訪問量 昨天大概60個ip, 同比上個星期是高點的,但是與星期四差別還是太大了。😂 昨日搜索引擎收錄情況 百度依舊0收錄 …

吳恩達《AI for everyone》第二周課程筆記

機器學習項目工作流程以Echo/Alexa(語音識別AI)作為例子解釋: 1. collect data 收集數據——人為找很多人說 Alexa,并錄制音頻;并且還會讓一群人說其他詞語,比如hello 2. train model 訓練模型——用機器學…

uniapp props、$ref、$emit、$parent、$child、$on

1. uniapp props、ref、ref、ref、emit、parent、parent、parent、child、$on 1.1. 父組件和子組件 propsPage.vue導入props-son-view.vue組件的時候,我們就稱index.vue為父組件依次類推,在vue中只要能獲取到組件的實例,那么就可以調用組件的屬性或是方法進行操作 1.2. pr…

4、ubuntu | dify創建知識庫 | 上市公司個股研報知識庫

1、創建知識庫步驟 創建一個知識庫并上傳相關文檔主要涉及以下五個關鍵步驟: 創建知識庫:首先,需要創建一個新的知識庫。這可以通過上傳本地文件、從在線資源導入數據或者直接創建一個空的知識庫來實現。 指定分段模式:接下來是…

Kubernetes中為Elasticsearch配置多節點共享存儲

在Kubernetes中為Elasticsearch配置多節點共享存儲(ReadWriteMany)需結合存儲后端特性及Elasticsearch架構設計。 由于Elasticsearch默認要求每個節點獨立存儲數據(ReadWriteOnce),直接實現多節點共享存儲需特殊處理。 ??方案一:使用支持ReadWriteMany的存儲后端(推薦…

SpringBoot熱部署與配置技巧

配置文件SpringBoot 的熱部署Spring為開發者提供了一個名為spring-boot-devtools的模塊來使SpringBoot應用支持熱部署&#xff0c;提高開發者的開發效率&#xff0c;無需手動重啟SpringBoot應用相關依賴&#xff1a;<dependency> <groupId>org.springframework.boo…