Introduction
In this post we build a convolutional neural network that distinguishes birds from airplanes, using all the bird and airplane images from the CIFAR-10 dataset as our data.
With this example we walk through the full, end-to-end life cycle of a neural network model: dataset preparation, data normalization, network definition, model training, validation of the results, and saving the model file, to give a deeper view of the whole AI model development process.
1. Network Scenario Definition and Dataset Preparation
1.1 Dataset Preparation
We use the CIFAR-10 dataset, a simple and fun dataset of 60,000 small RGB images (32×32 pixels), where each image carries a class label encoded as an integer from 0 to 9.
%matplotlib inline
from matplotlib import pyplot as plt
from torchvision import datasets

data_path = '/content/sample_data'
cifar10 = datasets.CIFAR10(data_path, train=True, download=True)
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True)
type(cifar10).__mro__

1.2 Viewing a Sample Image from Each Class
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
fig = plt.figure(figsize=(8, 3))
num_classes = 10
for i in range(num_classes):
    ax = fig.add_subplot(2, 5, 1 + i, xticks=[], yticks=[])
    ax.set_title(class_names[i])
    img = next(img for img, label in cifar10 if label == i)
    plt.imshow(img)
plt.show()
1.2.1 Printing a Single Image's Class and Displaying the Image
img, label = cifar10[99]
img, label, class_names[label]
plt.imshow(img)
plt.show()
1.3 Dataset Transforms
Using the torchvision.transforms module, we convert PIL images into PyTorch tensors for image classification.
1.3.1 Converting a Single Image to a Tensor and Printing the Tensor Size
from torchvision import transforms

to_tensor = transforms.ToTensor()
img_t = to_tensor(img)
img_t.shape
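ToTensor does two things: it lays the data out channels-first as (C, H, W), unlike PIL's (H, W, C), and it rescales the original 0–255 pixel values into floats in [0.0, 1.0]. A quick value-range check (an added sketch, not part of the original notebook):
img_t.min(), img_t.max()    # both values are expected to lie within [0.0, 1.0]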
1.3.2 Converting the CIFAR-10 Dataset to Tensors
tensor_cifar10 = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())
len(tensor_cifar10)
1.4 Data Normalization
Use transforms.Compose() to chain transforms together, so that normalization (and any data augmentation) is applied directly when the data loader fetches samples.
Use transforms.Normalize() with the per-channel mean and standard deviation computed over the dataset, so that each channel ends up with a mean of 0 and a standard deviation of 1.
import torch

imgs = torch.stack([img_t for img_t, _ in tensor_cifar10], dim=3)
imgs.shape
1.4.1 Computing the Mean of Each Channel
imgs.view(3, -1).mean(dim=1)
1.4.2 Computing the Standard Deviation (std) of Each Channel
imgs.view(3, -1).std(dim=1)
1.4.3 Normalizing the Dataset with transforms.Normalize()
This makes each channel of the dataset have a mean of 0 and a standard deviation of 1.
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
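To confirm the effect, here is a small verification sketch (an addition, assuming the mean/std values computed above): rebuild the dataset with ToTensor and Normalize chained via Compose, then re-check the per-channel statistics, which should now be close to 0 and 1.
transformed_cifar10 = datasets.CIFAR10(
    data_path, train=True, download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ]))
imgs_n = torch.stack([img_t for img_t, _ in transformed_cifar10], dim=3)
imgs_n.view(3, -1).mean(dim=1)  # expected to be approximately tensor([0., 0., 0.])
imgs_n.view(3, -1).std(dim=1)   # expected to be approximately tensor([1., 1., 1.])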
2. Writing the First Bird-vs-Airplane Classifier with nn.Module
2.1 Building the Bird and Airplane Training and Validation Sets
2.1.1 Preparing the CIFAR-10 Dataset
cifar10 = datasets.CIFAR10(data_path, train=True, download=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                    (0.2470, 0.2435, 0.2616))]))
cifar10_val = datasets.CIFAR10(data_path, train=False, download=False,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                        (0.2470, 0.2435, 0.2616))]))
2.1.2 Building the CIFAR2 Training Set
label_map = {0:0, 2:1}
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])
          for img, label in cifar10
          if label in [0, 2]]
len(cifar2)
2.1.3 Building the CIFAR2 Validation Set
cifar2_val = [(img, label_map[label])
              for img, label in cifar10_val
              if label in [0, 2]]
len(cifar2_val)
2.1.4 Preparing an Image for Batch Processing
img, _ = cifar2[0]
plt.imshow(img.permute(1, 2, 0))
plt.show()
img
img.shape
2.2 Writing the First Network Definition as an nn.Module Subclass
import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.act1 = nn.Tanh()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.act2 = nn.Tanh()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(8 * 8 * 8, 32)
        self.act3 = nn.Tanh()
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        out = self.pool1(self.act1(self.conv1(x)))
        out = self.pool2(self.act2(self.conv2(out)))
        out = out.view(-1, 8 * 8 * 8)
        out = self.act3(self.fc1(out))
        out = self.fc2(out)
        return out
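The fc1 input size of 8 * 8 * 8 comes from how the shapes flow through the network. A quick sanity-check sketch (an addition, assuming a single 32×32 RGB input):
net = Net()
dummy = torch.zeros(1, 3, 32, 32)            # one 3-channel 32x32 image
x = net.pool1(net.act1(net.conv1(dummy)))    # -> torch.Size([1, 16, 16, 16])
x = net.pool2(net.act2(net.conv2(x)))        # -> torch.Size([1, 8, 8, 8])
x.view(-1, 8 * 8 * 8).shape                  # -> torch.Size([1, 512])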
2.2.1 Instantiating the Model and Printing Its Parameter Counts
model = Net()
numel_list = [p.numel() for p in model.parameters()]
sum(numel_list), numel_list
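The total can also be checked by hand (an added worked count, not in the original): each convolution contributes out_channels × in_channels × k × k weights plus out_channels biases, and each linear layer out_features × in_features weights plus out_features biases.
manual_count = (
    16 * 3 * 3 * 3 + 16        # conv1: 448
    + 8 * 16 * 3 * 3 + 8       # conv2: 1160
    + 32 * (8 * 8 * 8) + 32    # fc1:   16416
    + 2 * 32 + 2               # fc2:   66
)
manual_count                   # 18090, which should match sum(numel_list)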
2.3 Using the Functional API to Streamline the nn.Module Definition
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8 * 8 * 8, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
        out = out.view(-1, 8 * 8 * 8)
        out = torch.tanh(self.fc1(out))
        out = self.fc2(out)
        return out

model = Net()
model(img.unsqueeze(0))
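The call above returns a 1×2 tensor of raw scores (logits) for the airplane and bird classes; the nn.CrossEntropyLoss used below applies softmax internally. To read the (still untrained) output as probabilities, a small added sketch:
with torch.no_grad():
    logits = model(img.unsqueeze(0))        # shape (1, 2): scores for [airplane, bird]
    probs = torch.softmax(logits, dim=1)    # probabilities that sum to 1
probs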
2.4 Defining the Training Loop Function and Running Training
import datetime

def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for imgs, labels in train_loader:
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train += loss.item()
        if epoch == 1 or epoch % 10 == 0:
            print('{} Epoch {}, Training loss{}'.format(
                datetime.datetime.now(), epoch, loss_train / len(train_loader)))

train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

training_loop(
    n_epochs=100,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
)
2.4.1 Training Results (took about 7 minutes)
2025-08-17 15:13:20.123706 Epoch 1, Training loss0.5672952472024663
2025-08-17 15:14:01.667640 Epoch 10, Training loss0.32902660861516453
2025-08-17 15:14:47.187795 Epoch 20, Training loss0.2960508146863075
2025-08-17 15:15:33.119990 Epoch 30, Training loss0.26820498961172284
2025-08-17 15:16:19.303661 Epoch 40, Training loss0.24607981879050564
2025-08-17 15:17:04.858228 Epoch 50, Training loss0.22783752284042394
2025-08-17 15:17:50.712569 Epoch 60, Training loss0.2095268357806145
2025-08-17 15:18:36.846523 Epoch 70, Training loss0.19460647420328894
2025-08-17 15:19:22.404563 Epoch 80, Training loss0.18098321051639357
2025-08-17 15:20:08.067236 Epoch 90, Training loss0.16757476806735536
2025-08-17 15:20:54.041604 Epoch 100, Training loss0.15512346253273593

2.5 Measuring Accuracy (Using the Validation Set)
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)

def validate(model, train_loader, val_loader):
    for name, loader in [("train", train_loader), ("val", val_loader)]:
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in loader:
                outputs = model(imgs)
                _, predicted = torch.max(outputs, dim=1)
                total += labels.shape[0]
                correct += int((predicted == labels).sum())
        print("Accuracy {}: {:.2f}".format(name, correct / total))

validate(model, train_loader, val_loader)
2.6 Saving and Loading the Model
2.6.1 Saving the Model
torch.save(model.state_dict(), data_path + '/birds_vs_airplanes.pt')
2.6.2 The Generated Model .pt File
The file contains all of the model's parameters, i.e. the weights and biases of the two convolutional modules and the two linear modules.
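As a quick way to see exactly what was saved (an added sketch, not part of the original write-up), the state_dict can be listed by parameter name and shape:
for name, param in model.state_dict().items():
    print(name, tuple(param.shape))
# conv1.weight (16, 3, 3, 3)
# conv1.bias (16,)
# ... followed by the conv2, fc1 and fc2 weights and biases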

2.6.3 Loading the Parameters into a Model Instance
loaded_model = Net()
loaded_model.load_state_dict(torch.load(data_path + '/birds_vs_airplanes.pt'))
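To confirm the reload worked (an added follow-up, reusing the validate() helper from section 2.5), the loaded model should report the same accuracies as the original one:
loaded_model.eval()                                # good habit, though this model has no dropout or batchnorm
validate(loaded_model, train_loader, val_loader)   # expected to match the accuracies printed earlier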
3. Summary
With that, we have built a convolutional neural network, birds_vs_airplanes, that can classify an image as either a bird or an airplane, reaching an accuracy of up to 94%!
We went end to end through dataset preparation, data normalization, network definition, model training, validation of the results, saving the model file, and loading the parameters into a fresh model instance, stringing together the complete life cycle of a neural network model and deepening our understanding of AI model development. It is a classic example, so give it a try!