手寫數字識別:使用PyTorch構建MNIST分類器
在這篇文章中,我將引導你通過使用PyTorch框架構建一個簡單的神經網絡模型,用于識別MNIST數據集中的手寫數字。MNIST數據集是一個經典的機器學習數據集,包含了60,000張訓練圖像和10,000張測試圖像,每張圖像都是28x28像素的灰度手寫數字。
環境準備
首先,確保你的環境中安裝了PyTorch和torchvision。可以通過以下命令安裝:
pip install torch torchvision
數據加載與預處理
我們首先加載MNIST數據集,并將圖像轉換為PyTorch張量格式,以便模型可以處理。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor'''下載訓練數據集(包含訓練圖片+標簽)'''
training_data = datasets.MNIST( #跳轉到函數的內部源代碼,pycharm 按下ctrl+鼠標點擊 training_data:Datasetroot="data",#表示下載的手寫數字 到哪個路徑。60000train=True, #讀取下載后的數據 中的 訓練集download=True,#如果你之前已經下載過了,就不用再下載transform=ToTensor(), #張量,圖片是不能直接傳入神經網絡模型
) #對于pytorch庫能夠識別的數據一般是tensor張量。'''下載測試數據集(包含訓練圖片+標簽)'''
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)
print(len(training_data))
數據可視化
為了更好地理解數據,我們可以展示一些手寫數字圖像。
''展示手寫字圖片,把訓練數據集中的前59000張圖片展示一下'''from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i+59000] #提取第59000張圖片figure.add_subplot(3, 3, i+1) #圖像窗口中創建多個小窗口,小窗口用于顯示圖片plt.title(label)plt.axis("off") # plt.show(I)#是示矢量,plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()
plt.show()
創建DataLoader
為了高效地加載數據,我們使用DataLoader
來批量加載數據。
# '"創建數據DataLoader(數據加載器)開'
# 'batch_size:將數據集分成多份,每一份為batch_size個數據'
# '優點:可以減少內存的使用,提高訓練速度。train_dataloader = DataLoader(training_data, batch_size=64) #64張圖片為一個包,train_dataloader:<torch
test_dataloader = DataLoader(test_data, batch_size=64)
模型定義
接下來,我們定義一個簡單的神經網絡模型,包含兩個隱藏層和一個輸出層。
'''定義神經網絡類的繼承這種方式'''
class NeuralNetwork(nn.Module): #通過調用類的形式來使用神經網絡,神經網絡的模型,nn.moduledef __init__(self): #python基礎關于類,self類自已本身super().__init__() #繼承的父類初始化self.flatten = nn.Flatten() #展開,創建一個展開對象flattenself.hidden1 = nn.Linear(28*28, 128 ) #第1個參數:有多少個神經元傳入進來,第2個參數:有多少個數據傳出self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = self.flatten(x) #圖像進行展開x = self.hidden1(x)x = torch.relu(x) #激活函數,torch使用的relu函數 relu,tanhx = self.hidden2(x)x = torch.relu(x)x = self.out(x)return xmodel = NeuralNetwork().to(device) #把剛剛創建的模型傳入到Gpu
print(model)
訓練與測試
我們定義訓練和測試函數,使用交叉熵損失函數和隨機梯度下降優化器。
def train(dataloader, model, loss_fn, optimizer):model.train() #告訴模型,我要開始訓練,模型中w進行隨機化操作,已經更新w。在訓練過程中,w會被修改的
# #pytorch提供2種方式來切換訓練和測試的模式,分別是:model.train()和 model.eval()。# 一般用法是:在訓練開始之前寫上model.trian(),在測試時寫上 model.eval()batch_size_num = 1for X, y in dataloader: #其中batch為每一個數據的編號X, y = X.to(device), y.to(device) #把訓練數據集和標簽傳入cpu或GPUpred = model.forward(X) #.forward可以被省略,父類中已經對次功能進行了設置。自動初始化wloss= loss_fn(pred, y) #通過交叉熵損失函數計算損失值loss# Backpropagation 進來一個batch的數據,計算一次梯度,更新一次網絡optimizer.zero_grad() #梯度值清零loss.backward() #反向傳播計算得到每個參數的梯度值woptimizer.step() #根據梯度更新網絡w參數loss_value = loss.item() #從tensor數據中提取數據出來,tensor獲取損失值if batch_size_num % 100 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):size = len(dataloader.dataset) #10000num_batches = len(dataloader) #打包的數量model.eval() #測試,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad(): #一個上下文管理器,關閉梯度計算。當你確認不會調用Tensor.backward()的時候。這for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item() #test_loss是會自動累加每一個批次的損失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y) #dim=1表示每一行中的最大值對應的索引號,dim=0表示每一列中的最大值b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches #能來衡量模型測試的好壞。correct /= size #平均的正確率print(f"Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}")
訓練模型
最后,我們訓練模型并測試其性能。
loss_fn = nn.CrossEntropyLoss() #創建交叉熵損失函數對象,因為手寫字識別中一共有10個數字,輸出會有10個結果optimizer = torch.optim.SGD(model.parameters(), lr=0.01) #創建一個優化器,SGD為隨機梯度下降算法
# #params:要訓練的參數,一般我們傳入的都是model.parameters()# #lr:learning_rate學習率,也就是步長#loss表示模型訓練后的輸出結果與,樣本標簽的差距。如果差距越小,就表示模型訓練越好,越逼近干真實的模型。# train(train_dataloader, model, loss_fn, optimizer)
# test(test_dataloader, model, loss_fn)epochs = 30
for t in range(epochs):print(f"Epoch {t+1}\n")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)
運行結果
結論
通過這篇文章,我們成功構建了一個簡單的神經網絡模型來識別MNIST數據集中的手寫數字。這個模型展示了如何使用PyTorch進行數據處理、模型定義、訓練和測試。希望這能幫助你開始自己的深度學習項目!