如何用torch框架訓練深度學習模型(詳解)
0. 需要的包
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
1. 數據加載和導入
以MNIST數據集為例
# 1.1 Input preprocessing: convert PIL images to tensors, then normalize
# with the MNIST dataset statistics (mean 0.1307, std 0.3081).
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# 1.2 Load the train and test splits with datasets.MNIST.
# Root directory that contains (or will receive) the MNIST/ folder.
# The original snippet used an undefined name `dataset_path` (NameError).
dataset_path = './data'
# download=False assumes the data already exists under dataset_path;
# set download=True on the first run to fetch it automatically.
train_dataset = datasets.MNIST(dataset_path, train=True, download=False, transform=train_transform)
test_dataset = datasets.MNIST(dataset_path, train=False, download=False, transform=test_transform)
# 1.3 Wrap the datasets in DataLoaders for batched iteration.
batch_size = 256
# shuffle=True reshuffles the training data every epoch; the original
# implicitly used shuffle=False, which hurts SGD convergence.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
補充:這里 transform 中的歸一化參數(均值/標準差)需要根據不同的數據集選擇不同的值。
datasets 加載數據集時 path 的路徑為 './data'(注意:'.\data\' 在 Python 中不是合法的字符串字面量,結尾的反斜杠會轉義引號;Windows 路徑請用 '.\\data' 或正斜杠寫法)。
該目錄下應包含 MNIST 文件夾。
2. 加載模型和設置超參數
# 2.1 Instantiate the model. The LeNet_Mnist class (layer structure and
# forward()) must be defined beforehand; it is not shown in this article.
# NOTE(review): `device` must also be defined earlier, e.g.
# device = 'cuda' if torch.cuda.is_available() else 'cpu' — confirm.
model = LeNet_Mnist().to(device)
# 2.2 Optimizer, loss function and number of training epochs.
learning_rate = 1e-2
# Pass the model parameters to the optimizer so it can update them.
sgd = SGD(model.parameters(), lr=learning_rate)
loss_fn = CrossEntropyLoss()
all_epoch = 20
3. 訓練
# 3.1 Switch to training mode (enables dropout / batch-norm training behavior).
model.train()
# 3.2 Iterate over the training set one batch at a time.
# (The original collapsed this loop onto a single unindented line,
# which is not valid Python; reformatted here.)
for idx, (train_x, train_label) in enumerate(train_loader):
    train_x = train_x.to(device)
    train_label = train_label.to(device)
    sgd.zero_grad()                      # clear gradients from the previous step
    predict_y = model(train_x.float())   # forward pass -> class logits
    loss = loss_fn(predict_y, train_label.long())
    loss.backward()                      # back-propagate
    sgd.step()                           # update the parameters
補充:實際訓練時應在第 3 步(訓練)和第 4 步(測試)的代碼外再套一層 epoch 循環,並將其整體縮進一級:
for current_epoch in range(all_epoch):  # 每個 epoch 訓練一遍並在測試集上評估
4. 測試
# 4.1 Accumulators for the test accuracy.
all_correct_num = 0
all_sample_num = 0
# 4.2 Switch to evaluation mode (disables dropout, uses running batch-norm
# statistics). Note: eval() does NOT disable gradient tracking — the
# torch.no_grad() context below takes care of that.
model.eval()
# 4.3 Evaluate batch by batch without building a computation graph.
# (The original collapsed this loop onto one line and called np.sum
# although numpy was never imported — a NameError; fixed with torch.)
with torch.no_grad():
    for idx, (test_x, test_label) in enumerate(test_loader):
        test_x = test_x.to(device)
        test_label = test_label.to(device)
        predict_y = model(test_x.float())
        predict_y = torch.argmax(predict_y, dim=-1)   # predicted class per sample
        current_correct_num = predict_y == test_label
        all_correct_num += current_correct_num.sum().item()
        all_sample_num += current_correct_num.shape[0]
# 4.4 Report the overall accuracy.
acc = all_correct_num / all_sample_num
print('accuracy: {:.3f}'.format(acc), flush=True)
5. 保存結果
# 5.1 Save only the trained weights (the usual choice for inference).
print("Save the model state dict")
torch.save(model.state_dict(), "./lenet_mnist.pt")
# 5.2 Alternatively save a full checkpoint (model + optimizer state) once
# per epoch, so training can resume after an interruption.
checkpoint = {
    "model": model.state_dict(),
    "optim": sgd.state_dict(),
}
print("Save the checkpoint")
# NOTE(review): `current_epoch` only exists when the epoch loop from the
# supplementary note in section 3 is used; otherwise this raises NameError.
torch.save(checkpoint, "./checkpoint{}.pt".format(current_epoch))