【大模型微調系列-04】 神經網絡基礎與小項目實戰

【大模型微調系列-04】 神經網絡基礎與小項目實戰

💡 本章目標:通過構建一個能識別手寫數字的AI模型,讓你真正理解神經網絡是如何"學習"的。2-3小時后,你將擁有第一個自己訓練的AI模型!

4.1 理論講解:神經網絡的核心機制

4.1.1 激活函數:讓神經網絡"活"起來

想象一下,如果神經網絡只是簡單地把輸入乘以權重再相加,那會發生什么?答案是:無論疊加多少層,最終都等價于一個簡單的線性變換。這就像無論你用多少個直尺去畫圖,最終也只能畫出直線,永遠畫不出曲線。

激活函數就是神經網絡的"魔法開關",它把線性運算的結果進行非線性變換,讓網絡能夠學習復雜的模式。

輸入 x
線性變換
z = Wx + b
激活函數
a = f z
非線性輸出
為什么必須要激活函數?

讓我們用一個簡單的數學例子來說明。假設有一個兩層網絡,沒有激活函數:

第一層:y? = W?x + b?
第二層:y? = W?y? + b?
展開后:y? = W?(W?x + b?) + b? = (W?W?)x + (W?b? + b?)

看到了嗎?兩層網絡最終可以簡化為 y = Wx + b 的形式,和單層網絡沒有區別!這就是為什么我們需要激活函數來打破這種線性疊加。

三種常用激活函數

讓我們通過代碼直觀地看看這些激活函數長什么樣:

import numpy as np
import matplotlib.pyplot as plt# 創建輸入數據
x = np.linspace(-5, 5, 100)# 定義三種激活函數
def relu(x):"""ReLU: 過濾負值,保留正值"""return np.maximum(0, x)def sigmoid(x):"""Sigmoid: 壓縮到0-1之間,像概率"""return 1 / (1 + np.exp(-x))def tanh(x):"""Tanh: 壓縮到-1到1之間,中心化的sigmoid"""return np.tanh(x)# 繪制激活函數圖像
plt.figure(figsize=(12, 4))plt.subplot(1, 3, 1)
plt.plot(x, relu(x), 'b-', linewidth=2)
plt.title('ReLU: 過濾器')
plt.xlabel('輸入')
plt.ylabel('輸出')
plt.grid(True, alpha=0.3)
plt.axhline(y=0, color='k', linewidth=0.5)
plt.axvline(x=0, color='k', linewidth=0.5)plt.subplot(1, 3, 2)
plt.plot(x, sigmoid(x), 'r-', linewidth=2)
plt.title('Sigmoid: 概率轉換器')
plt.xlabel('輸入')
plt.ylabel('輸出')
plt.grid(True, alpha=0.3)
plt.axhline(y=0.5, color='k', linewidth=0.5, linestyle='--')plt.subplot(1, 3, 3)
plt.plot(x, tanh(x), 'g-', linewidth=2)
plt.title('Tanh: 中心化壓縮器')
plt.xlabel('輸入')
plt.ylabel('輸出')
plt.grid(True, alpha=0.3)
plt.axhline(y=0, color='k', linewidth=0.5)
plt.axvline(x=0, color='k', linewidth=0.5)plt.tight_layout()
plt.show()

💭 理解檢查

  • ReLU為什么叫"整流"?因為它像電子元件中的整流器,只讓正信號通過
  • Sigmoid為什么常用于二分類?因為它的輸出恰好在0-1之間,可以解釋為概率
  • Tanh相比Sigmoid的優勢?輸出以0為中心,有助于下一層的學習
XOR問題:激活函數的威力展示

XOR(異或)問題是一個經典例子,它無法用直線分割,但加入激活函數后就能輕松解決:

# XOR問題的數據
X = np.array([[0,0], [0,1], [1,0], [1,1]])
y = np.array([0, 1, 1, 0])  # XOR的輸出# 沒有激活函數:永遠無法正確分類
# 有激活函數:能夠學習非線性邊界

📝 一句話總結:激活函數是神經網絡的"非線性變換器",讓網絡能夠學習曲線、圓形等復雜邊界,而不僅僅是直線。

4.1.2 前向傳播與反向傳播:神經網絡的學習機制

前向傳播:數據的流水線

想象神經網絡是一條生產流水線,原材料(輸入數據)經過每個工作站(網絡層)的加工,最終產出成品(預測結果)。

前向傳播流程
第一層
W1*x+b1
輸入
28*28圖像
ReLU
激活
第二層
W2*a1+b2
ReLU
激活
輸出層
W3*a2+b3
預測
數字0-9

每一層的計算都很簡單:

  1. 線性變換z = W×x + b (權重矩陣乘以輸入,加上偏置)
  2. 激活函數a = f(z) (引入非線性)
  3. 傳遞輸出:這一層的輸出成為下一層的輸入
反向傳播:智能的糾錯機制

如果前向傳播是"考試答題",那反向傳播就是"批改試卷"。它從最終的錯誤開始,逐層往回找每個參數應該承擔多少"責任"。

反向傳播流程
輸出層梯度
dL/dW3
計算誤差
loss
第二層梯度
dL/dW2
第一層梯度
dL/dW1
更新所有參數
W = W - lr*梯度

核心概念解釋

  • 梯度(Gradient):就是"改變的方向"。想象你在山上,梯度告訴你哪個方向下山最快
  • 鏈式法則(Chain Rule):錯誤的傳遞就像接力賽,每一層都要把"接力棒"(梯度)傳給前一層
  • 參數更新W_new = W_old - 學習率 × 梯度

讓我們用代碼演示一個簡化版的反向傳播:

# 簡化的反向傳播示例
import torch# 創建一個簡單的計算圖
x = torch.tensor([1.0], requires_grad=True)
w = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([1.0], requires_grad=True)# 前向傳播
y = w * x + b  # y = 2*1 + 1 = 3
loss = (y - 5) ** 2  # 假設目標是5,loss = (3-5)2 = 4# 反向傳播(PyTorch自動完成)
loss.backward()print(f"x的梯度: {x.grad}")  # 告訴我們x對loss的影響
print(f"w的梯度: {w.grad}")  # 告訴我們w對loss的影響
print(f"b的梯度: {b.grad}")  # 告訴我們b對loss的影響# 參數更新(梯度下降)
learning_rate = 0.1
w_new = w - learning_rate * w.grad
print(f"更新后的w: {w_new}")

🎯 動手試試:修改上面代碼中的目標值(從5改為其他數字),觀察梯度如何變化。

📝 一句話總結:前向傳播計算預測結果,反向傳播計算如何調整參數,兩者配合讓網絡不斷學習進步。

4.1.3 訓練與驗證:如何讓模型真正"學會"

數據集劃分:考試系統的智慧

訓練神經網絡就像準備考試,我們需要合理安排學習和測試:

完整數據集
70,000張圖片
訓練集 60%
42,000張圖片
用來學習
驗證集 20%
14,000張圖片
檢驗效果
測試集 20%
14,000張圖片
最終考試
  • 訓練集:相當于課本習題,模型通過它學習規律
  • 驗證集:相當于模擬考試,用來調整學習策略(超參數)
  • 測試集:相當于期末考試,只在最后用一次,評估真實能力
過擬合:死記硬背的陷阱

過擬合就像學生死記硬背答案,考試題目稍微變化就不會做了。讓我們看看過擬合的表現:

# 繪制訓練過程中的loss曲線
epochs = np.arange(1, 21)# 模擬三種情況
# 正常情況
train_loss_normal = 1.0 / np.sqrt(epochs) + 0.1
val_loss_normal = 1.0 / np.sqrt(epochs) + 0.15# 欠擬合
train_loss_under = 0.8 - 0.01 * epochs + 0.5
val_loss_under = 0.8 - 0.01 * epochs + 0.52# 過擬合
train_loss_over = 1.0 / epochs
val_loss_over = 1.0 / epochs[:10].tolist() + (0.1 + 0.02 * epochs[10:]).tolist()plt.figure(figsize=(12, 4))# 繪制三種情況
for i, (train, val, title) in enumerate([(train_loss_under, val_loss_under, '欠擬合:還沒學會'),(train_loss_normal, val_loss_normal, '正常:恰到好處'),(train_loss_over, val_loss_over, '過擬合:死記硬背')
], 1):plt.subplot(1, 3, i)plt.plot(epochs, train, 'b-', label='訓練loss', linewidth=2)plt.plot(epochs, val, 'r--', label='驗證loss', linewidth=2)plt.xlabel('訓練輪數')plt.ylabel('Loss')plt.title(title)plt.legend()plt.grid(True, alpha=0.3)plt.tight_layout()
plt.show()

防止過擬合的技巧

  1. Early Stopping(早停):驗證loss不再下降就停止訓練
  2. Dropout(隨機失活):訓練時隨機"關閉"一些神經元,防止過度依賴
  3. 數據增強:對圖片旋轉、縮放,制造更多訓練樣本

📝 一句話總結:合理劃分數據集并監控驗證指標,是讓模型真正"理解"而非"死記"的關鍵。

4.1.4 模型保存與加載:AI的"存檔系統"

訓練好的模型就像游戲存檔,需要妥善保存以便后續使用:

# 保存模型的兩種方式# 方式1:保存整個模型(結構+參數)
torch.save(model, 'my_model.pth')
# 優點:加載簡單
# 缺點:文件較大,依賴代碼版本# 方式2:只保存參數(推薦)
torch.save(model.state_dict(), 'model_params.pth')
# 優點:文件小,跨版本兼容性好
# 缺點:加載時需要先定義模型結構# 保存訓練狀態(用于中斷后繼續訓練)
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pth')

模型文件包含什么?

  • 網絡每一層的權重矩陣和偏置
  • 批歸一化層的統計信息(如果有)
  • 優化器的動量信息(如果保存checkpoint)

4.2 實操案例:構建你的第一個AI - MNIST手寫數字識別

現在讓我們動手實現一個完整的神經網絡項目!MNIST手寫數字識別是深度學習的"Hello World"。

4.2.1 項目概覽與環境準備

# 導入必要的庫
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm  # 顯示進度條# 檢查是否有GPU可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'使用設備: {device}')# 設置隨機種子,保證結果可復現
torch.manual_seed(42)
np.random.seed(42)

4.2.2 數據加載與探索

# 定義數據預處理
transform = transforms.Compose([transforms.ToTensor(),  # 將圖片轉為Tensortransforms.Normalize((0.1307,), (0.3081,))  # 標準化(MNIST的均值和標準差)
])# 下載并加載MNIST數據集
print("正在下載MNIST數據集...")
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform
)# 創建數據加載器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)print(f"訓練集大小: {len(train_dataset)} 張圖片")
print(f"測試集大小: {len(test_dataset)} 張圖片")
print(f"每批次大小: {batch_size}")
print(f"訓練批次數: {len(train_loader)}")# 可視化一些樣本
def show_samples():"""展示一些MNIST樣本"""fig, axes = plt.subplots(2, 5, figsize=(12, 6))axes = axes.ravel()# 獲取一批數據images, labels = next(iter(train_loader))for i in range(10):# 反標準化用于顯示img = images[i].squeeze()img = img * 0.3081 + 0.1307axes[i].imshow(img, cmap='gray')axes[i].set_title(f'標簽: {labels[i].item()}')axes[i].axis('off')plt.suptitle('MNIST手寫數字樣本')plt.tight_layout()plt.show()show_samples()# 打印數據形狀,幫助理解
sample_image, sample_label = train_dataset[0]
print(f"\n單張圖片形狀: {sample_image.shape}")  # torch.Size([1, 28, 28])
print(f"展平后的維度: {sample_image.flatten().shape}")  # torch.Size([784])

💡 理解要點

  • MNIST圖片是28×28像素的灰度圖
  • 展平后變成784維的向量,作為網絡輸入
  • 標簽是0-9的數字,需要預測10個類別

4.2.3 定義神經網絡模型

class MLP(nn.Module):"""多層感知器(MLP)網絡結構:784輸入 → 128隱藏 → 64隱藏 → 10輸出"""def __init__(self):super(MLP, self).__init__()# 定義網絡層self.fc1 = nn.Linear(784, 128)  # 輸入層到第一隱藏層self.fc2 = nn.Linear(128, 64)   # 第一隱藏層到第二隱藏層self.fc3 = nn.Linear(64, 10)    # 第二隱藏層到輸出層# 定義激活函數self.relu = nn.ReLU()# Dropout層,防止過擬合self.dropout = nn.Dropout(0.2)  # 20%的神經元會被隨機關閉def forward(self, x):"""前向傳播過程"""# 將28×28的圖片展平成784維向量x = x.view(-1, 784)  # -1表示自動計算批次大小# 第一層:線性變換 + 激活 + Dropoutx = self.fc1(x)       # [batch_size, 784] → [batch_size, 128]x = self.relu(x)      # ReLU激活x = self.dropout(x)   # Dropout# 第二層:線性變換 + 激活 + Dropoutx = self.fc2(x)       # [batch_size, 128] → [batch_size, 64]x = self.relu(x)      # ReLU激活x = self.dropout(x)   # Dropout# 輸出層:線性變換(不需要激活函數)x = self.fc3(x)       # [batch_size, 64] → [batch_size, 10]return x# 創建模型實例
model = MLP().to(device)# 打印模型結構
print("模型結構:")
print(model)# 計算模型參數量
total_params = sum(p.numel() for p in model.parameters())
print(f"\n總參數量: {total_params:,}")
MLP網絡結構
隱藏層1
128個神經元
ReLU激活
輸入層
784個神經元
28*28圖片
Dropout
20%
隱藏層2
64個神經元
ReLU激活
Dropout
20%
輸出層
10個神經元
數字0-9

4.2.4 訓練循環實現

# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()  # 交叉熵損失,適合多分類
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam優化器# 記錄訓練過程
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []def train_epoch(model, loader, criterion, optimizer):"""訓練一個epoch"""model.train()  # 設置為訓練模式(啟用Dropout)total_loss = 0correct = 0total = 0# 使用tqdm顯示進度條progress_bar = tqdm(loader, desc='訓練中')for batch_idx, (images, labels) in enumerate(progress_bar):# 將數據移到GPU(如果有)images, labels = images.to(device), labels.to(device)# ========= 5步訓練流程 =========# 步驟1:清零梯度(必須的,否則梯度會累積)optimizer.zero_grad()# 步驟2:前向傳播outputs = model(images)# 步驟3:計算損失loss = criterion(outputs, labels)# 步驟4:反向傳播loss.backward()# 步驟5:更新參數optimizer.step()# ==============================# 統計指標total_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 更新進度條顯示if batch_idx % 10 == 0:progress_bar.set_postfix({'Loss': f'{loss.item():.4f}','Acc': f'{100.*correct/total:.2f}%'})avg_loss = total_loss / len(loader)accuracy = 100. * correct / totalreturn avg_loss, accuracydef validate(model, loader, criterion):"""驗證模型性能"""model.eval()  # 設置為評估模式(關閉Dropout)total_loss = 0correct = 0total = 0# 不需要計算梯度,節省內存with torch.no_grad():for images, labels in tqdm(loader, desc='驗證中'):images, labels = images.to(device), labels.to(device)outputs = model(images)loss = criterion(outputs, labels)total_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()avg_loss = total_loss / len(loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy# 開始訓練
num_epochs = 10
best_val_acc = 0print("="*50)
print("開始訓練...")
print("="*50)for epoch in range(num_epochs):print(f'\nEpoch [{epoch+1}/{num_epochs}]')# 訓練train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)train_losses.append(train_loss)train_accuracies.append(train_acc)# 驗證val_loss, val_acc = validate(model, test_loader, criterion)val_losses.append(val_loss)val_accuracies.append(val_acc)print(f'訓練 - Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%')print(f'驗證 - Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%')# 保存最佳模型if val_acc > best_val_acc:best_val_acc = val_acctorch.save(model.state_dict(), 'best_model.pth')print(f'? 保存最佳模型 (驗證準確率: {val_acc:.2f}%)')print(f'\n訓練完成!最佳驗證準確率: {best_val_acc:.2f}%')

🔍 常見錯誤提示

  • 如果出現"CUDA out of memory",減小batch_size
  • 如果loss變成NaN,檢查學習率是否過大
  • 如果準確率一直不提升,檢查數據預處理是否正確

4.2.5 訓練曲線可視化

def plot_training_history():"""繪制訓練歷史曲線"""fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))epochs_range = range(1, len(train_losses) + 1)# 繪制Loss曲線ax1.plot(epochs_range, train_losses, 'b-', label='訓練Loss', linewidth=2)ax1.plot(epochs_range, val_losses, 'r-', label='驗證Loss', linewidth=2)ax1.set_xlabel('Epoch')ax1.set_ylabel('Loss')ax1.set_title('Loss曲線')ax1.legend()ax1.grid(True, alpha=0.3)# 標注最低驗證Lossmin_val_loss_epoch = np.argmin(val_losses) + 1ax1.plot(min_val_loss_epoch, val_losses[min_val_loss_epoch-1], 'ro', markersize=10)ax1.annotate(f'最低驗證Loss\nEpoch {min_val_loss_epoch}', xy=(min_val_loss_epoch, val_losses[min_val_loss_epoch-1]),xytext=(min_val_loss_epoch+1, val_losses[min_val_loss_epoch-1]+0.05),arrowprops=dict(arrowstyle='->', color='red'))# 繪制準確率曲線ax2.plot(epochs_range, train_accuracies, 'b-', label='訓練準確率', linewidth=2)ax2.plot(epochs_range, val_accuracies, 'r-', label='驗證準確率', linewidth=2)ax2.set_xlabel('Epoch')ax2.set_ylabel('準確率 (%)')ax2.set_title('準確率曲線')ax2.legend()ax2.grid(True, alpha=0.3)# 標注最高驗證準確率max_val_acc_epoch = np.argmax(val_accuracies) + 1ax2.plot(max_val_acc_epoch, val_accuracies[max_val_acc_epoch-1], 'go', markersize=10)ax2.annotate(f'最高準確率\n{val_accuracies[max_val_acc_epoch-1]:.2f}%', xy=(max_val_acc_epoch, val_accuracies[max_val_acc_epoch-1]),xytext=(max_val_acc_epoch-2, val_accuracies[max_val_acc_epoch-1]-3),arrowprops=dict(arrowstyle='->', color='green'))plt.suptitle('訓練過程監控', fontsize=14)plt.tight_layout()plt.show()plot_training_history()

4.2.6 模型評估與錯誤分析

def evaluate_model():"""詳細評估模型性能"""model.eval()# 收集所有預測結果all_predictions = []all_labels = []all_probs = []with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)# 獲取預測概率和類別probs = torch.softmax(outputs, dim=1)_, predicted = torch.max(outputs, 1)all_predictions.extend(predicted.cpu().numpy())all_labels.extend(labels.cpu().numpy())all_probs.extend(probs.cpu().numpy())# 計算混淆矩陣from sklearn.metrics import confusion_matriximport seaborn as snscm = confusion_matrix(all_labels, all_predictions)plt.figure(figsize=(10, 8))sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=range(10), yticklabels=range(10))plt.title('混淆矩陣')plt.ylabel('真實標簽')plt.xlabel('預測標簽')plt.show()# 找出預測錯誤的樣本errors = []for i in range(len(all_predictions)):if all_predictions[i] != all_labels[i]:errors.append({'index': i,'true': all_labels[i],'pred': all_predictions[i],'confidence': max(all_probs[i]) * 100})print(f"錯誤樣本數: {len(errors)} / {len(all_predictions)}")print(f"錯誤率: {len(errors)/len(all_predictions)*100:.2f}%")# 展示一些錯誤案例if errors:print("\n部分錯誤案例:")for error in errors[:5]:print(f"樣本{error['index']}: 真實={error['true']}, "f"預測={error['pred']}, 置信度={error['confidence']:.1f}%")return errorserrors = evaluate_model()# 可視化錯誤案例
def show_error_cases(num_cases=6):"""展示預測錯誤的案例"""if not errors:print("沒有錯誤案例!")returnfig, axes = plt.subplots(2, 3, figsize=(12, 8))axes = axes.ravel()# 隨機選擇一些錯誤案例error_indices = np.random.choice(len(errors), min(num_cases, len(errors)), replace=False)for idx, ax in enumerate(axes):if idx >= len(error_indices):ax.axis('off')continueerror = errors[error_indices[idx]]# 獲取對應的圖片test_data = test_dataset[error['index']][0]img = test_data.squeeze()img = img * 0.3081 + 0.1307  # 反標準化ax.imshow(img, cmap='gray')ax.set_title(f"真實: {error['true']}, 預測: {error['pred']}\n"f"置信度: {error['confidence']:.1f}%",color='red')ax.axis('off')plt.suptitle('預測錯誤案例分析', fontsize=14)plt.tight_layout()plt.show()show_error_cases()

4.2.7 模型保存與加載

# 保存完整的訓練狀態
def save_checkpoint(model, optimizer, epoch, loss, accuracy, filename='checkpoint.pth'):"""保存訓練檢查點"""checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,'accuracy': accuracy,'model_architecture': str(model),}torch.save(checkpoint, filename)print(f"? 檢查點已保存到 {filename}")# 保存當前模型
save_checkpoint(model, optimizer, num_epochs, val_losses[-1], val_accuracies[-1])# 加載模型進行推理
def load_model_for_inference():"""加載訓練好的模型"""# 創建新的模型實例loaded_model = MLP().to(device)# 加載參數loaded_model.load_state_dict(torch.load('best_model.pth'))loaded_model.eval()print("? 模型加載成功!")return loaded_model# 測試加載的模型
loaded_model = load_model_for_inference()# 用加載的模型預測單張圖片
def predict_single_image(model, image_tensor):"""預測單張圖片"""model.eval()with torch.no_grad():image_tensor = image_tensor.unsqueeze(0).to(device)  # 添加batch維度output = model(image_tensor)probabilities = torch.softmax(output, dim=1)predicted_class = torch.argmax(output, dim=1)confidence = torch.max(probabilities) * 100return predicted_class.item(), confidence.item()# 測試預測功能
test_image, test_label = test_dataset[0]
pred_class, confidence = predict_single_image(loaded_model, test_image)
print(f"預測結果: {pred_class}, 置信度: {confidence:.2f}%, 真實標簽: {test_label}")

4.2.8 進階:CNN網絡實現(選做)

class SimpleCNN(nn.Module):"""簡單的卷積神經網絡相比MLP的優勢:參數共享、局部感知"""def __init__(self):super(SimpleCNN, self).__init__()# 卷積層self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)# 池化層self.pool = nn.MaxPool2d(2, 2)# 全連接層self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)# 激活和Dropoutself.relu = nn.ReLU()self.dropout = nn.Dropout(0.25)def forward(self, x):# 第一個卷積塊x = self.conv1(x)           # [batch, 1, 28, 28] → [batch, 32, 28, 28]x = self.relu(x)x = self.pool(x)            # [batch, 32, 28, 28] → [batch, 32, 14, 14]# 第二個卷積塊x = self.conv2(x)           # [batch, 32, 14, 14] → [batch, 64, 14, 14]x = self.relu(x)x = self.pool(x)            # [batch, 64, 14, 14] → [batch, 64, 7, 7]# 展平并通過全連接層x = x.view(-1, 64 * 7 * 7)  # 展平x = self.fc1(x)x = self.relu(x)x = self.dropout(x)x = self.fc2(x)return x# 創建CNN模型
cnn_model = SimpleCNN().to(device)# 計算參數量對比
mlp_params = sum(p.numel() for p in model.parameters())
cnn_params = sum(p.numel() for p in cnn_model.parameters())print("模型參數量對比:")
print(f"MLP: {mlp_params:,} 參數")
print(f"CNN: {cnn_params:,} 參數")
print(f"CNN參數量是MLP的 {cnn_params/mlp_params*100:.1f}%")# 性能對比表格
comparison_data = {'模型': ['MLP', 'CNN'],'參數量': [mlp_params, cnn_params],'準確率': ['~98%', '~99%'],'訓練時間/epoch': ['~30秒', '~45秒'],'優勢': ['簡單易懂', '更高準確率']
}import pandas as pd
comparison_df = pd.DataFrame(comparison_data)
print("\n性能對比:")
print(comparison_df.to_string(index=False))

4.3 本章小結與練習

🎯 學習成果檢查清單

完成本章學習后,你應該能夠:

  • 解釋為什么神經網絡需要激活函數
  • 描述前向傳播和反向傳播的基本流程
  • 獨立完成MNIST手寫數字識別項目
  • 繪制并解讀訓練曲線
  • 保存和加載訓練好的模型
  • 識別并處理過擬合問題

💡 關鍵概念回顧

  1. 激活函數:引入非線性,讓網絡能學習復雜模式
  2. 前向傳播:數據從輸入層流向輸出層的計算過程
  3. 反向傳播:根據誤差調整網絡參數的過程
  4. 梯度下降:沿著梯度的反方向更新參數
  5. 過擬合:模型在訓練集上表現好但泛化能力差

🚀 進階練習

  1. 調參實驗

    • 改變隱藏層大小(如256、512),觀察效果
    • 調整學習率(0.01、0.001、0.0001),找出最優值
    • 修改Dropout率,觀察對過擬合的影響
  2. 數據增強

    • 對MNIST圖片進行隨機旋轉(-15°到15°)
    • 添加隨機噪聲,提高模型魯棒性
  3. 新數據集挑戰

    • 嘗試Fashion-MNIST(服裝分類)
    • 挑戰CIFAR-10(彩色圖片分類)

📚 推薦資源

  • PyTorch官方教程:https://pytorch.org/tutorials/
  • 可視化神經網絡:http://playground.tensorflow.org/
  • MNIST數據集詳情:http://yann.lecun.com/exdb/mnist/

🎉 恭喜你!

你已經成功訓練了第一個神經網絡!這是深度學習之旅的重要里程碑。下一章,我們將深入了解Qwen大模型的架構,看看這些基礎知識如何應用到數十億參數的大模型中。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/919161.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/919161.shtml
英文地址,請注明出處:http://en.pswp.cn/news/919161.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

JavaWeb前端(HTML,CSS具體案例)

前言 一直在學習B站黑馬程序員蒼穹外賣。現在已經學的差不多了,但是我學習一直是針對后端開發的,前端也沒太注重去學(他大部分都給課程資料嘻嘻🤪),但我還是比較感興趣,準備先把之前學JavaWeb&…

核心數據結構:DataFrame

3.3.1 創建與訪問什么是 DataFrame?DataFrame 是 Pandas 中的核心數據結構之一,多行多列表格數據,類似于 Excel 表格 或 SQL 查詢結果。它是一個 二維表格結構,具有行索引(index)和列標簽(colu…

深入探索Go語言標準庫 net 包中的 IP 處理

深入探索Go語言標準庫 net 包中的 IP 處理 文章目錄深入探索Go語言標準庫 net 包中的 IP 處理引言核心知識type IP常用函數常用方法代碼示例常見問題1. DNS 查詢失敗怎么辦?2. 如何區分 IPv4 和 IPv6 地址?使用場景1. 服務器端編程2. 網絡監控和調試3. 防…

2.4 雙向鏈表

目錄 引入 結構定義 結構操作 初始化 插入 刪除 打印 查找 隨機位置插入 隨機位置刪除 銷毀 總結 數據結構專欄https://blog.csdn.net/xyl6716/category_13002640.html 精益求精 追求卓越 【代碼倉庫】:Code Is Here 【合作】 :apollomona…

開發指南132-DOM的寬度、高度屬性

寬度、高度類似。這里以高度為例來說明DOM中有關高度的概念:1、height取法:element.style.height說明:元素內容區域的高度,不含padding、border、margin該屬性可寫2、clientHeight取法:element..clientHeight&#xff…

魔改chromium源碼——解除 iframe 的同源策略

在進行以下操作之前,請確保已完成之前文章中提到的 源碼拉取及編譯 部分。 如果已順利完成相關配置,即可繼續執行后續操作。 同源策略限制了不同源(協議、域名、端口)的網頁腳本訪問彼此的資源。iframe 的跨域限制由 Blink 渲染引擎和 Chromium 的安全層共同實現。 咱們直…

在鴻蒙中實現深色/淺色模式切換:從原理到可運行 Demo

摘要 現在幾乎所有主流應用都支持“深色模式”和“淺色模式”切換,這已經成了用戶習慣。鴻蒙(HarmonyOS)同樣提供了兩種模式(dark / light),并且支持應用根據系統主題切換,或者應用內手動切換。…

Redux搭檔Next.js的簡明使用教程

Redux 是一個用于 JavaScript 應用的狀態管理庫,主要解決組件間共享狀態和復雜狀態邏輯的問題。當應用規模較大、組件層級較深或多個組件需要共享/修改同一狀態時,Redux 可以提供可預測、可追蹤的狀態管理方式,避免狀態在組件間混亂傳遞。Red…

SCAI采用公平發射機制成功登陸LetsBonk,60%代幣供應量已鎖倉

去中心化科學(DeSci)平臺SCAI宣布,其代幣已于今日以Fair Launch形式在LetsBonk.fun平臺成功發射。為保障資金安全與透明,開發團隊已將代幣總量的60%進行鎖倉,進一步提升社區信任與項目合規性。SCAI是一個專注于高質量科…

【Kubernetes系列】Kubernetes中的resources

博客目錄1. limits(資源上限)2. requests(資源請求)關鍵區別其他注意事項示例場景在 Kubernetes (k8s) 中,resources 用于定義容器的資源請求(requests)和限制(limits)&a…

hadoop 前端yarn 8088端口查看任務執行情況

圖中資源相關參數含義及簡單分析思路&#xff1a; 基礎資源搶占參數 Total Resource Preempted: <memory:62112, vCores:6> 含義&#xff1a;應用總共被搶占的資源量&#xff0c; memory:62112 表示累計被收回的內存&#xff08;單位通常是MB &#xff0c;結合Hadoop生態…

基于SpringBoot的個性化教育學習平臺的設計與實現(源碼+lw+部署文檔+講解等)

課題介紹在教育數字化轉型與學習者需求差異化的背景下&#xff0c;傳統學習平臺 “統一內容、統一進度” 的模式已顯局限。當前&#xff0c;平臺多提供標準化課程資源&#xff0c;無法根據學習者年齡、基礎、目標&#xff08;如升學、技能提升&#xff09;定制學習路徑&#xf…

UE5多人MOBA+GAS 48、制作閃現技能

文章目錄添加標簽添加GA_Blink添加標簽 CRUNCH_API UE_DECLARE_GAMEPLAY_TAG_EXTERN(Ability_Blink_Teleport)CRUNCH_API UE_DECLARE_GAMEPLAY_TAG_EXTERN(Ability_Blink_Cooldown)UE_DEFINE_GAMEPLAY_TAG_COMMENT(Ability_Blink_Teleport, "Ability.Blink.Teleport"…

Swift 實戰:實現一個簡化版的 Twitter(LeetCode 355)

文章目錄摘要描述示例解決答案設計思路題解代碼分析測試示例和結果時間復雜度空間復雜度總結摘要 在社交媒體平臺里&#xff0c;推送機制是核心功能之一。比如你關注了某人&#xff0c;就希望在自己的時間線上能看到他們的最新消息&#xff0c;同時自己的消息也要能出現在別人…

在瀏覽器端使用 xml2js 遇到的報錯及解決方法

在瀏覽器端使用 xml2js 遇到的報錯及解決方法 一、引言 在前端開發過程中&#xff0c;我們常常需要處理 XML 數據。xml2js 是一個非常流行的用于將 XML 轉換為 JavaScript 對象的庫。然而&#xff0c;當我們在瀏覽器端使用它時&#xff0c;可能會遇到一些問題。本文將介紹在瀏覽…

eChart餅環pie中間顯示總數_2個以上0值不擠掉

<!DOCTYPE html> <html> <head><meta charset"utf-8"><title>環餅圖顯示總數</title><script src"https://cdn.jsdelivr.net/npm/echarts5.4.3/dist/echarts.min.js"></script><style>#main { widt…

Ansible 核心功能進階:自動化任務的靈活控制與管理

一、管理 FACTS&#xff1a;獲取遠程主機的 “身份信息”FACTS 是 Ansible 自動收集的遠程主機詳細信息&#xff08;類似 “主機身份證”&#xff09;&#xff0c;包括主機名、IP、系統版本、硬件配置等。通過 FACTS 可以動態獲取主機信息&#xff0c;讓 Playbook 更靈活1. 查看…

gRPC網絡模型詳解

gRPC協議框架 TCP層&#xff1a;底層通信協議&#xff0c;基于TCP連接。 TLS層&#xff1a;該層是可選的&#xff0c;基于TLS加密通道。 HTTP2層&#xff1a;gRPC承載在HTTP2協議上&#xff0c;利用了HTTP2的雙向流、流控、頭部壓縮、單連接上的多 路復用請求等特性。 gRPC層…

[優選算法專題二滑動窗口——將x減到0的最小操作數]

題目鏈接 將x減到0的最小操作數 題目描述 題目解析 問題重述 給定一個整數數組 nums 和一個整數 x&#xff0c;每次只能從數組的左端或右端移除一個元素&#xff0c;并將該元素的值從 x 中減去。我們需要找到將 x 恰好減為 0 的最少操作次數&#xff0c;如果不可能則返回 -…

AOP配置類自動注入

本文主要探究AopAutoConfiguration配置類里面的bean怎么被自動裝配的。代碼如下&#xff1a;package com.example.springdemo.demos.a05;import com.example.springdemo.demos.a04.Bean1; import com.example.springdemo.demos.a04.Bean2; import com.example.springdemo.demos…