目錄
一、什么是全連接網絡
二、代碼實現步驟
1. 導入必要的庫
2. 數據準備
3. 定義網絡結構
4. 模型訓練
5. 模型保存和加載
6. 預測單張圖片
7. 主函數
三、運行結果說明
四、小結
一、什么是全連接網絡
全連接神經網絡(Fully Connected Neural Network)是一種最基礎的神經網絡結構,其特點是每一層的每個神經元都與上一層的所有神經元相連。
打個比方,就像公司里的部門架構:輸入層是基層員工,隱藏層是中層管理,輸出層是高層決策。基層的每個人都要向所有中層匯報,中層再向所有高層匯報,這樣信息就能經過多層處理后得到最終結果。
但全連接網絡處理圖像時有個缺點:它會把圖像的二維像素矩陣轉換成一維向量,這就像把一張完整的圖片撕成一條線,會丟失圖像的空間特征。
二、代碼實現步驟
1. 導入必要的庫
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
這些庫就像我們的工具包:
torch
?是 PyTorch 的核心庫nn
?模塊包含神經網絡相關的工具optim
?提供優化器torchvision
?有現成的數據集和圖像處理工具DataLoader
?幫助我們批量加載數據PIL
?用于處理圖像
2. 數據準備
def build_data():transform = transforms.Compose([transforms.ToTensor(),])train_set = datasets.MNIST(root = '../dataset',train = True,download = True,transform = transform)test_set = datasets.MNIST(root = '../dataset',train = False,download = True,transform = transform)train_loader = DataLoader(dataset = train_set,batch_size = 128,shuffle = True)test_loader = DataLoader(dataset = test_set,batch_size = 64,shuffle = True)return train_loader, test_loader
這段代碼做了三件事:
- 定義了數據轉換方式,
ToTensor()
會把圖像轉換成張量并歸一化 - 加載 MNIST 數據集(手寫數字數據集,包含 0-9 共 10 類數字)
- 用
DataLoader
把數據分成批次,方便訓練時批量處理
batch_size
表示每次處理多少張圖片,shuffle=True
表示打亂數據順序,讓模型學習更全面。
3. 定義網絡結構
class MNISTNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(28 * 28, 256)self.relu1 = nn.ReLU()self.fc2 = nn.Linear(256, 128)self.relu2 = nn.ReLU()self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28) # 把28x28的圖像展平成784維向量x = self.relu1(self.fc1(x))x = self.relu2(self.fc2(x))x = self.fc3(x)return x
我們定義了一個 3 層的全連接網絡:
- 輸入層:MNIST 圖像是 28x28 的,展平后是 784 個像素點
- 第一個隱藏層:256 個神經元,使用 ReLU 激活函數
- 第二個隱藏層:128 個神經元,同樣使用 ReLU 激活函數
- 輸出層:10 個神經元(對應 0-9 十個數字)
激活函數 ReLU 的作用是引入非線性,讓網絡能夠學習復雜的模式,就像給計算器增加了更多運算功能。
4. 模型訓練
def train(model, train_loader, epochs):criterion = nn.CrossEntropyLoss() # 交叉熵損失函數,適合分類問題opt = optim.SGD(model.parameters(), lr=0.01) # 隨機梯度下降優化器for epoch in range(epochs):loss_sum = 0count = 0for x, y in train_loader:y_pred = model(x) # 前向傳播,得到預測結果loss = criterion(y_pred, y) # 計算損失# 反向傳播更新參數opt.zero_grad() # 清空梯度loss.backward() # 計算梯度opt.step() # 更新參數loss_sum += loss.item()_, pred = torch.max(y_pred, dim=1) # 找到概率最大的類別count += (pred == y).sum().item() # 統計正確的數量acc = count / len(train_loader.dataset) # 計算準確率print(f'epoch: {epoch+1}, Loss: {loss_sum:.4f}, Acc: {acc:.4f}')
訓練過程就像學生做習題:
- 先用當前模型做預測(前向傳播)
- 計算預測結果和正確答案的差距(損失函數)
- 分析哪里錯了,怎么改進(反向傳播計算梯度)
- 調整模型參數(優化器更新參數)
我們用交叉熵損失函數來衡量預測錯誤的程度,用隨機梯度下降(SGD)來優化模型參數,學習率lr=0.01
控制每次調整的幅度。
5. 模型保存和加載
def save_model(model, model_path):torch.save(model.state_dict(), model_path) # 保存模型參數def load_model(model_path):model = MNISTNet()model.load_state_dict(torch.load(model_path)) # 加載模型參數return model
訓練好的模型可以保存下來,下次用的時候直接加載,不用重新訓練,就像保存游戲進度一樣。
6. 預測單張圖片
def predict(model, filePath):img = Image.open(filePath)# 圖像預處理:調整大小、轉成張量、歸一化transform = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])t_img = transform(img)with torch.no_grad(): # 預測時不需要計算梯度y_pred = model(t_img)_, pred = torch.max(y_pred, dim=1)print(f'預測結果: {pred.item()}')
預測時需要對輸入圖片做和訓練數據相同的預處理,with torch.no_grad()
可以加快計算速度,因為預測時不需要更新參數。
7. 主函數
if __name__ == '__main__':train_loader, test_loader = build_data()model = MNISTNet()# 訓練模型train(model, train_loader, epochs=10)# 保存模型save_model(model, './mnist.pt')# 加載模型并預測model_pred = load_model('./mnist.pt')predict(model_pred, './img/3.png')
三、運行結果說明
訓練過程中,我們會看到損失(Loss)逐漸減小,準確率(Acc)逐漸提高,這說明模型在不斷進步。
對于 MNIST 這種簡單數據集,用這個全連接網絡通常能達到 97% 以上的準確率。如果想進一步提高性能,可以考慮使用卷積神經網絡(CNN),它能更好地保留圖像的空間特征。
四、小結
本文用 PyTorch 實現了一個全連接神經網絡來識別 MNIST 手寫數字,主要步驟包括:
- 準備數據:加載并預處理 MNIST 數據集
- 定義網絡:設計 3 層全連接網絡
- 訓練模型:使用交叉熵損失和 SGD 優化器
- 保存和加載模型:方便復用
- 單張圖片預測:實際應用模型
全連接網絡雖然簡單,但它是理解更復雜神經網絡的基礎。通過這個例子,我們可以了解神經網絡的基本工作原理和 PyTorch 的使用方法。