Task: MLP神經網絡的訓練
1.PyTorch和cuda的安裝
2.查看顯卡信息的命令行命令(cmd中使用)
3.cuda的檢查
4.簡單神經網絡的流程
a.數據預處理(歸一化、轉換成張量)
b.模型的定義
i.繼承nn.Module類
ii.定義每一個層
iii.定義前向傳播流程
c.定義損失函數和優化器
d.定義訓練流程
e.可視化loss過程
MLP神經網絡訓練復習
1. PyTorch 與 CUDA 安裝
- PyTorch安裝:推薦使用官方命令(根據你的CUDA版本)例如:
pip install torch torchvision torchaudio
或使用conda:
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
- CUDA支持:確保你的GPU支持CUDA,安裝匹配版本的顯卡驅動和CUDA Toolkit。
2. 查看顯卡信息(命令行)
- Windows CMD:
nvidia-smi
- 查看GPU詳細信息(在PyTorch中也可以用代碼查詢)
3. CUDA的檢查
在Python中:
import torch
print(torch.cuda.is_available()) # 查看CUDA是否可用
print(torch.cuda.device_count()) # 當前GPU數量
print(torch.cuda.get_device_name(0)) # GPU設備名
4. 簡單神經網絡流程
a. 數據預處理
- 歸一化:將數據縮放到某個范圍(通常0-1或-1到1)
- 轉換為張量:
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)
])
b. 模型定義
- 繼承
nn.Module
- 定義網絡層
- 實現
forward()
方法
示例:
import torch.nn as nnclass MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x
c. 損失函數和優化器
- 損失函數:
nn.CrossEntropyLoss()
,nn.MSELoss()
等 - 優化器:
torch.optim.SGD
,torch.optim.Adam
示例:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
d. 訓練流程
- 遍歷數據
- 清零梯度
optimizer.zero_grad()
- 前向傳播
- 計算損失
- 反向傳播
loss.backward()
- 更新參數
optimizer.step()
示例:
for epoch in range(num_epochs):for data, labels in dataloader:data = data.to(device)labels = labels.to(device)outputs = model(data)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()
e. 可視化Loss
用matplotlib繪制訓練過程中loss變化:
import matplotlib.pyplot as pltlosses = []# 在訓練循環中每輪加入損失
losses.append(loss.item())plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()
小結
- 先配置好環境(PyTorch、CUDA)
- 理解神經網絡的訓練流程:數據預處理 -> 模型定義 -> 損失函數/優化器 -> 訓練循環 -> 可視化
- 牢記模型定義中的繼承nn.Module的重要性
- 熟悉GPU利用(cuda)檢測及使用技巧