文章目錄
- 模型訓練驗證
- 損失函數和優化器
- 模型優化
- 訓練函數
- 驗證函數
- 模型保存
模型訓練驗證
損失函數和優化器
loss_function = nn.CrossEntropyLoss() # 損失函數
optimizer = Adam(model.parameters()) # 優化器,優化參數
模型優化
獲得模型所有的可訓練參數(比如每一層的權重、偏置),設置優化器類型,自動調整學習步長(自適應學習率),后續訓練更新參數。
# 雇傭Adam教練,讓他管理模型參數
optimizer = Adam(model.parameters(), lr=0.001) # lr是初始學習率
# 1. optimizer.zero_grad() # 清空上一輪的成績單
# 2. loss.backward() # 計算每個參數要改進的方向(梯度)
# 3. optimizer.step() # 參數調整
訓練函數
def train():loss = 0accuracy = 0model.train()for x, y in train_loader: # 獲得每個batch數據x, y = x.to(device), y.to(device)output = model(x) # 得到預測labeloptimizer.zero_grad() # 梯度清零batch_loss = loss_function(output, y) # 計算batch誤差batch_loss.backward() # 計算誤差梯度optimizer.step() # 調整模型參數loss += batch_loss.item()accuracy += get_batch_accuracy(output, y, train_N)print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
驗證函數
def validate():loss = 0accuracy = 0model.eval() # 評估模式,關閉隨機性等增加穩定性with torch.no_grad(): # 禁用梯度,提高效率for x, y in valid_loader:x, y = x.to(device), y.to(device)output = model(x) # 不用進行梯度計算、參數調整loss += loss_function(output, y).item()accuracy += get_batch_accuracy(output, y, valid_N)print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
模型保存
.pth 文件是PyTorch模型的“存檔文件”,保存了所有必要信息。加載后,模型即可直接運行,無需重新訓練!
# 保存整個模型(結構 + 參數)
torch.save(model, 'model.pth')
.pth 文件可以用https://netron.app/查看