第P1周:Pytorch實現mnist手寫數字識別

  • 🍨 本文為🔗365天深度學習訓練營 中的學習記錄博客
  • 🍖 原作者:K同學啊

目標

1. 實現pytorch環境配置
2. 實現mnist手寫數字識別
3. 自己寫幾個數字識別試試

具體實現

(一)環境

語言環境:Python 3.10
編 譯 器: PyCharm
框 架:

(二)具體步驟
**1.**配置Pytorch環境

打開官網PyTorch,Get started:
image.png
接下來是選擇安裝版本,最難的就是確定Compute Platform的版本,是否要使用GPU。所以先要確定CUDA的版本。
image.png
會發現,pytorch官網根本沒有對應12.7的版本,先安裝最新的試試唄,選擇12.4:
image.png
安裝命令:pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
image.png
image.png
安裝完成,我們建立python文件,輸入如下代碼:

import torch  
x = torch.rand(5, 3)  
print(x)  print(torch.cuda.is_available())---------output---------------
tensor([[0.3952, 0.6351, 0.3107],[0.8780, 0.6469, 0.6714],[0.4380, 0.0236, 0.5976],[0.4132, 0.9663, 0.7576],[0.4047, 0.4636, 0.2858]])
True

從輸出來看,成功了。下面開始正式的mnist手寫數字識別

2. 下載數據并加載數據
import torch  
import torch.nn as nn  
# import matplotlib.pyplot as plt  
import torchvision  # 第一步:設置硬件設備,有GPU就使用GPU,沒有就使用GPU  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  
print(device)  # 第二步:導入數據  
# MNIST數據在torchvision.datasets中,自帶的,可以通過代碼在線下載數據。  
train_ds = torchvision.datasets.MNIST(root='./data',    # 下載的數據所存儲的本地目錄  train=True,       # True為訓練集,False為測試集  transform=torchvision.transforms.ToTensor(),  # 將下載的數據直接轉換成張量格式  download=True     # True直接在線下載,且下載到root指定的目錄中,注意已經下載了,第二次以后就不會再下載了  )  
test_ds = torchvision.datasets.MNIST(root='./data',  train=False,  transform=torchvision.transforms.ToTensor(),  download=True  )  # 第三步:加載數據  
# Pytorch使用torch.utils.data.DataLoader進行數據加載  
batch_size = 32  
train_dl = torch.utils.data.DataLoader(dataset=train_ds, # 要加載的數據集  batch_size=batch_size, # 批次的大小  shuffle=True,     # 每個epoch重新排列數據  # 以下的參數有默認值可以不寫  num_workers=0, # 用于加載的子進程數,默認值為0.注意在windows中如果設置非0,有可能會報錯  pin_memory=True, # True-數據加載器將在返回之前將張量復制到設備/CUDA 固定內存中。 如果數據元素是自定義類型,或者collate_fn返回一個自定義類型的批次。  drop_last=False, #如果數據集大小不能被批次大小整除,則設置為 True 以刪除最后一個不完整的批次。 如果 False 并且數據集的大小不能被批大小整除,則最后一批將保留。 (默認值:False)  timeout=0, # 設置數據讀取的超時時間 , 超過這個時間還沒讀取到數據的話就會報錯。(默認值:0)  worker_init_fn=None # 如果不是 None,這將在步長之后和數據加載之前在每個工作子進程上調用,并使用工作 id([0,num_workers - 1] 中的一個 int)的順序逐個導入。(默認:None)  )  # 取一個批次看一下數據格式,數據的shape為[batch_size, channel, height, weight]  
# batch_size是已經設定的32,channel, height和weight分別是圖片的通道數,高度和寬度  
images, labels = next(iter(train_dl))  
print(images.shape)

image.png
image.png
看這個圖片的shape是torch.size([32, 1, 28, 28]),可以看圖MNIST的數據集里的圖像我猜應該是單色的(channel=1),28 * 28大小的圖片(height=28, weight=28)。
將圖片可視化展示出來看看:

# 數據可視化  
plt.figure(figsize=(20, 5)) # 指定圖片大小 ,圖像大小為20寬,高5的繪圖(單位為英寸)  
for i , images in enumerate(images[:20]):  # 維度縮減,npimg = np.squeeze(images.numpy())  # 將整個figure分成2行10列,繪制第i+1個子圖  plt.subplot(2, 10, i+1)  plt.imshow(npimg, cmap=plt.cm.binary)  plt.axis('off')  
plt.show()

image.png

**3.**構建CNN網絡
num_classes = 10 # MNIST數據集中是識別0-9這10個數字,因此是10個類別。class Model(nn.Module):def __init__(self):super(Model, self).__init__()# 特征提取網絡self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) # 第一層卷積,卷積核大小3*3self.pool1 = nn.MaxPool2d(2)    # 池化層,池化核大小為2*2self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # 第二層卷積,卷積核大小3*3self.pool2 = nn.MaxPool2d(2)# 分類網絡self.fc1 = nn.Linear(1600, 64)self.fc2 = nn.Linear(64, num_classes)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = torch.flatten(x, start_dim=1)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 第四步:加載并打印模型
# 將模型轉移到GPU中
model = Model().to(device)
summary(model)>)

image.png

4.訓練模型
# 第五步:訓練模型  
loss_fn = nn.CrossEntropyLoss() # 創建損失函數  
learn_rate = 1e-2   # 設置學習率  
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)  # 循環訓練  
def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset) # 訓練集的大小  num_batches = len(dataloader) # 批次數目  train_loss, train_acc = 0, 0  # 初始化訓練損失率和正確率都為0  for X, y in dataloader: # 獲取圖片及標簽  X, y = X.to(device), y.to(device)   # 將圖片和標準轉換到GPU中  # 計算預測誤差  pred = model(X) # 使用CNN網絡預測輸出pred  loss = loss_fn(pred, y) # 計算預測輸出的pred和真實值y之間的差距  # 反向傳播  optimizer.zero_grad()   # grad屬性歸零  loss.backward() # 反向傳播  optimizer.step()    # 第一步自動更新  # 記錄acc與loss  train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  train_loss += loss.item()  train_acc /= size  train_loss /= num_batches  return train_acc, train_loss  # 測試函數,注意測試函數不需要進行梯度下降,不進行網絡權重更新,所以不需要傳入優化器  
def test(dataloader, model, loss_fn):  size = len(dataloader.dataset)  num_batches = len(dataloader)  test_loss, test_acc = 0, 0  # 當不進行訓練時,停止梯度更新,節省計算內存消耗  with torch.no_grad():  for imgs, targets in dataloader:  imgs, target = imgs.to(device), targets.to(device)  # 計算 loss            target_pred = model(imgs)  loss = loss_fn(target_pred, target)  test_loss += loss.item()  test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  test_acc /= size  test_loss /= num_batches  return test_acc, test_loss  # 正式訓練  
epochs = 5  
train_loss, train_acc, test_loss, test_acc = [], [], [], []  for epoch in range(epochs):  model.train()  epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)  model.eval()  epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)  train_acc.append(epoch_train_acc)  test_acc.append(epoch_test_acc)  train_loss.append(epoch_train_loss)  test_loss.append(epoch_test_loss)  template = 'Epoch: {:2d}, Train_acc:{:.1f}%, Train_loss: {:.3f}%, Test_acc: {:.1f}%, Test_loss: {:.3f}%'  print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))  
print('Done')

image.png

# 可見化一下訓練結果  
warnings.filterwarnings("ignore")  
plt.rcParams['font.sans-serif'] = ['SimHei']    # 顯示中文不標簽,不設置會顯示中文亂碼  
plt.rcParams['axes.unicode_minus'] = False      # 顯示負號  
plt.rcParams['figure.dpi'] = 100                # 分辨率  epochs_range = range(epochs)  plt.figure(figsize=(12, 3))  
plt.subplot(1, 2, 1)  plt.plot(epochs_range, train_acc, label='訓練正確率')  
plt.plot(epochs_range, test_acc, label='測試正確率')  
plt.legend(loc='lower right')  
plt.title('訓練與測試正確率')  plt.subplot(1, 2, 2)  
plt.plot(epochs_range, train_loss, label='訓練損失率')  
plt.plot(epochs_range, test_loss, label='測試損失率')  
plt.legend(loc='upper right')  
plt.title('訓練與測試損失率')  plt.show()

image.png

四:預測一下自己手寫的數字

準備數據:
image.png
再手動將每個數字切割成單獨的一個文件:
image.png
注意,這里并沒有將每個圖片的大小切割成一致,理論上切割成要求的28*28是最好。我這里用代碼來重新生成28 * 28大小的圖片。

import torch  
import numpy as np  
from PIL import Image  
from torchvision import transforms  
import torch.nn as nn  
import torch.nn.functional as F  
import matplotlib.pyplot as plt  
import os, pathlib  # 第一步:設置硬件設備,有GPU就使用GPU,沒有就使用GPU  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  
print(device)  # 定義模型,要把模型搞過來嘛,不然加載模型會出錯。  
class Model(nn.Module):  def __init__(self):  super().__init__()  # 特征提取網絡  self.conv1 = nn.Conv2d(1, 32, kernel_size=3 ) # 第一層卷積,卷積核大小3*3  self.pool1 = nn.MaxPool2d(2)    # 池化層,池化核大小為2*2  self.conv2 = nn.Conv2d(32, 64, kernel_size=3) # 第二層卷積,卷積核大小3*3  self.pool2 = nn.MaxPool2d(2)  # 分類網絡  self.fc1 = nn.Linear(1600, 64)  self.fc2 = nn.Linear(64, 10)  def forward(self, x):  x = self.pool1(F.relu(self.conv1(x)))  x = self.pool2(F.relu(self.conv2(x)))  x = torch.flatten(x, start_dim=1)  x = F.relu(self.fc1(x))  x = self.fc2(x)  return x  # 加載模型  
model = torch.load('./models/cnn.pth')   
model.eval()  transform = transforms.Compose([  transforms.ToTensor(),  transforms.Normalize((0.1307,), (0.3081,))  
])  # 導入數據  
data_dir = "./mydata/handwrite"  
data_dir = pathlib.Path(data_dir)  
image_count = len(list(data_dir.glob('*.jpg')))  
print("圖片總數量為:", image_count)  plt.rcParams['font.sans-serif'] = ['SimHei']    # 顯示中文不標簽,不設置會顯示中文亂碼  
plt.rcParams['axes.unicode_minus'] = False      # 顯示負號  
plt.rcParams['figure.dpi'] = 100                # 分辨率  
plt.figure(figsize=(10, 10))  
i = 0  
for input_file in list(data_dir.glob('*.jpg')):  image = Image.open(input_file)  image_resize = image.resize((28, 28))   # 將圖片轉換成 28*28  image = image_resize.convert('L')  # 轉換成灰度圖  image_array = np.array(image)  # print(image_array.shape)    # (high, weight)  image = Image.fromarray(image_array)  image = transform(image)  image = torch.unsqueeze(image, 0)   # 返回維度為1的張量  image = image.to(device)  output = model(image)  pred = torch.argmax(output, dim=1)  image = torch.squeeze(image, 0)     # 返回一個張量,其中刪除了大小為1的輸入的所有指定維度  image = transforms.ToPILImage()(image)  plt.subplot(10, 4, i+1)  plt.tight_layout()  plt.imshow(image, cmap='gray', interpolation='none')  plt.title("實際值:{},預測值:{}".format(input_file.stem[:1], pred.item()))  plt.xticks([])  plt.yticks([])  i += 1  
plt.show()

image.png

準確性很低,40張圖片預測準確數量:6,占比:15.0%.。看圖片,感覺resize成28*28和轉換成灰度圖后,圖片本身已經失真比較嚴重了。先把圖片像素翻轉一下,其實就是反色處理,加上這段代碼:
image.png
image.png
準確率上了一個臺階(40張圖片預測準確數量:30,占比:75.0%).。但是看圖片,還是不清晰。

(三)總結
  1. epochs=5,預測的準確性達到97%,如果增加迭代的次數到10,準確性提升接近到99%。迭代20次則達到99.3,提升不明顯。
    image.png
    image.png
  2. batch_size如何從32調整到64,準確性差不太多
    image.png
    image.png
  3. 后續研究圖片增強

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/62232.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/62232.shtml
英文地址,請注明出處:http://en.pswp.cn/web/62232.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Seq2Seq模型的發展歷史;深層RNN結構為什么出現梯度消失/爆炸問題,Transformer為什么不會;Seq2Seq模型存在問題

目錄 Seq2Seq模型的發展歷史 改進不足的地方 深層RNN結構為什么出現梯度消失/爆炸問題,Transformer為什么不會 深層RNN結構為什么出現梯度消失/爆炸問題: Transformer為什么不會出現梯度消失/爆炸問題: Seq2Seq模型存在問題 T5模型介紹 Seq2Seq模型的發展歷史 序列到…

網絡安全技術詳解:虛擬專用網絡(VPN) 安全信息與事件管理(SIEM)

虛擬專用網絡(VPN)詳細介紹 虛擬專用網絡(VPN)通過在公共網絡上創建加密連接來保護數據傳輸的安全性和隱私性。 工作原理 VPN的工作原理涉及建立安全隧道和數據加密: 隧道協議:使用協議如PPTP、L2TP/IP…

Hive 窗口函數與分析函數深度解析:開啟大數據分析的新維度

Hive 窗口函數與分析函數深度解析:開啟大數據分析的新維度 在當今大數據蓬勃發展的時代,Hive 作為一款強大的數據倉庫工具,其窗口函數和分析函數猶如一把把精巧的手術刀,助力數據分析師們精準地剖析海量數據,挖掘出深…

SCAU期末筆記 - 數據庫系統概念

我校使用Database System Concepts,9-12章不考所以跳過,因為課都逃了所以復習很倉促,只準備過一下每一章最后的概念辨析,我也不知道有沒有用 第1章 引言 數據庫管理系統(DBMS) 由一個互相關聯的數據的集合…

Android 12系統源碼_窗口管理(九)深淺主題切換流程源碼分析

前言 上一篇我們簡單介紹了應用的窗口屬性WindowConfiguration這個類,該類存儲了當前窗口的顯示區域、屏幕的旋轉方向、窗口模式等參數,當設備屏幕發生旋轉的時候就是通過該類將具體的旋轉數據傳遞給應用的、而應用在加載資源文件的時候也會結合該類的A…

河南省的教育部科技查新工作站有哪些?

鄭州大學圖書館(Z12):2007年1月被批準設立“教育部綜合類科技查新工作站”,同年12月被河南省科技廳認定為河南省省級科技查新機構。主要面向河南省的高校、科研機構、企業提供科技查新、查收查引等服務。 河南大學圖書館&#xf…

Leetcode經典題6--買賣股票的最佳時機

買賣股票的最佳時機 題目描述: 給定一個數組 prices ,它的第 i 個元素 prices[i] 表示一支給定股票第 i 天的價格。 你只能選擇 某一天 買入這只股票,并選擇在 未來的某一個不同的日子 賣出該股票。設計一個算法來計算你所能獲取的最大利潤。…

MCPTT 與BTC

MCPTT(Mission Critical Push-to-Talk)和B-TrunC(寬帶集群)是兩種關鍵通信標準,它們分別由不同的組織制定和推廣。 MCPTT(Mission Critical Push-to-Talk)標準由3GPP(第三代合作伙伴…

去除賬號密碼自動賦值時的輸入框背景色

問題描述: 前端使用賬號密碼登錄,若在網頁保存過當前頁面的密碼和賬號,那么當再次進入該頁面,網頁會自動的把賬號和密碼賦到輸入框中,而此時輸入框是帶有背景色的,與周邊的白色背景顯得很不協調&#xff1…

【Pytorch】torch.reshape與torch.Tensor.reshape區別

問題引入: 在Pytorch文檔中,有torch.reshape與torch.Tensor.reshape兩個reshape操作,他們的區別是什么呢? 我們先來看一下官方文檔的定義: torch.reshape: torch.Tensor.reshape: 解釋: 在p…

掃碼與短信驗證碼登錄JS逆向分析與Python純算法還原

文章目錄 1. 寫在前面2. 掃碼接口分析2. 短信接口分析3. 加密算法還原【??作者主頁】:吳秋霖 【??作者介紹】:擅長爬蟲與JS加密逆向分析!Python領域優質創作者、CSDN博客專家、阿里云博客專家、華為云享專家。一路走來長期堅守并致力于Python與爬蟲領域研究與開發工作!…

spring6:3容器:IoC

spring6:3容器:IoC 目錄 spring6:3容器:IoC3、容器:IoC3.1、IoC容器3.1.1、控制反轉(IoC)3.1.2、依賴注入3.1.3、IoC容器在Spring的實現 3.2、基于XML管理Bean3.2.1、搭建子模塊spring6-ioc-xml…

【認證法規】安全隔離變壓器

文章目錄 定義反激電源變壓器 定義 安全隔離變壓器(safety isolating transformer),通過至少相當于雙重絕緣或加強絕緣的絕緣使輸入繞組與輸出繞組在電氣上分開的變壓器。這種變壓器是為以安全特低電壓向配電電路、電器或其它設備供電而設計…

車機端同步outlook日歷

最近在開發一個車機上的日歷助手,其中一個需求就是要實現手機端日歷和車機端日歷數據的同步。然而這種需求似乎沒辦法實現,畢竟手機日歷是手機廠商自己帶的系統應用,根本不能和車機端實現數據同步的。 那么只能去其他公共的平臺尋求一些機會&…

OpenCV-圖像閾值

簡單閾值法 此方法是直截了當的。如果像素值大于閾值,則會被賦為一個值(可能為白色),否則會賦為另一個值(可能為黑色)。使用的函數是 cv.threshold。第一個參數是源圖像,它應該是灰度圖像。第二…

力扣300.最長遞增子序列

題目描述 題目鏈接300. 最長遞增子序列 給你一個整數數組 nums ,找到其中最長嚴格遞增子序列的長度。 子序列 是由數組派生而來的序列,刪除(或不刪除)數組中的元素而不改變其余元素的順序。例如,[3,6,2,7] 是數組 […

Vue CLI的作用

Vue CLI(Command Line Interface)是一個基于Vue.js的官方腳手架工具,其主要作用是幫助開發者快速搭建Vue項目的基礎結構和開發環境。以下是Vue CLI的具體作用: 1、項目模板與快速生成 Vue CLI提供了一系列預設的項目模板&#x…

【藍橋杯每日一題】掃雷

掃雷 知識點 2024-12-3 藍橋杯每日一題 掃雷 dfs (bfs也是可行的) 題目大意 在一個二維平面上放置這N個炸雷,每個炸雷的信息有$(x_i,y_i,r_i) $,前兩個是坐標信息,第三個是爆炸半徑。然后會輸入M個排雷火箭&#xff0…

【大數據學習 | 面經】Spark 3.x 中的AQE(自適應查詢執行)

Spark 3.x 中的自適應查詢執行(Adaptive Query Execution,簡稱 AQE)通過多種方式提升性能,主要包括以下幾個方面: 動態合并 Shuffle 分區(Coalescing Post Shuffle Partitions): 當 …

城電科技 | 光伏景觀長廊 打造美麗鄉村綠色低碳示范區 光伏景觀設計方案

光伏景觀長廊是一種結合了光伏發電技術和零碳景觀設計的新型公共公共設施,光伏景觀長廊頂上的光伏板不僅可以為周邊用電設備提供清潔電能,而且還能作為遮陽設施使用,為人們提供一個美麗又實用的休閑娛樂空間。 光伏景觀長廊建設對打造美麗鄉…