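1. Requirements
Practical Project 4 is built in three parts. This second part is a tutorial on visualized cat-vs-dog image classification training with PyTorch. It implements a complete training pipeline for a cat/dog classifier, fine-tunes a pretrained ResNet50 through transfer learning, and tracks the experiment with SwanLab.
Screenshot of the result: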

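2. Implementation Approach
Overall approach
- Imports and configuration: set the training hyperparameters (learning rate, batch size, number of epochs, etc.);
- Dataset loading: read the custom dataset and set up a data loader;
- Model construction: load a pretrained ResNet50 and replace its fully connected layer to fit the binary classification task;
- Training setup: define a cross-entropy loss function and an Adam optimizer;
- Model training: loop over the epochs, iterate over every batch in each epoch, print the training progress in real time, and log the loss to SwanLab.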
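2.1 Imports and Configuration
```python
import swanlab

num_epochs = 20
lr = 1e-4
batch_size = 8
num_classes = 2
device = "cuda"  # requires a CUDA-capable GPU; use "cpu" otherwise

swanlab.init(
    experiment_name="模型訓練實驗",  # "Model training experiment"
    description="貓狗分類",  # "Cat/dog classification"
    mode="local",
    config={
        "model": "resnet50",
        "optim": "Adam",
        "lr": lr,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "num_class": num_classes,
        "device": device,
    },
)
```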
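- `import swanlab`: import the SwanLab library for experiment tracking and visualization
- `num_epochs = 20`: train for 20 epochs
- `lr = 1e-4`: set the learning rate to 0.0001
- `batch_size = 8`: use batches of 8 images
- `num_classes = 2`: two classes (cat and dog)
- `device = "cuda"`: train on the GPU
- `swanlab.init()`: initialize the SwanLab experiment and record the configuration parameters; `mode="local"` keeps the run on the local machine instead of syncing it to the SwanLab cloud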
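In the script above every hyperparameter is written twice, once as a variable and once in the `config` dictionary passed to `swanlab.init`. A minimal sketch of one way to keep them in a single place, assuming a hypothetical `TrainConfig` dataclass (it is not part of the original script or of SwanLab), uses `dataclasses.asdict` to produce the plain dictionary that `config=` receives:

```python
from dataclasses import dataclass, asdict, replace


@dataclass(frozen=True)
class TrainConfig:
    """Hypothetical single source of truth for the section 2.1 hyperparameters."""
    model: str = "resnet50"
    optim: str = "Adam"
    lr: float = 1e-4
    batch_size: int = 8
    num_epochs: int = 20
    num_class: int = 2
    device: str = "cuda"


cfg = TrainConfig()
config_dict = asdict(cfg)  # same keys as the config dict in the tutorial
print(config_dict)

# frozen=True blocks accidental mutation; derive variants with replace() instead
cpu_cfg = replace(cfg, device="cpu")
print(cpu_cfg.device)
```

In the training script this would become `swanlab.init(..., config=asdict(cfg))`, with the rest of the code reading `cfg.lr`, `cfg.batch_size`, and so on, so the logged configuration cannot drift from the values actually used.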
2.2 Loading the Dataset
```python
import readDataset
from torch.utils.data import DataLoader

train_dataset = readDataset.DatasetLoader(readDataset.ds_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
```
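- `import readDataset`: import the custom dataset-reading module
- `from torch.utils.data import DataLoader`: import PyTorch's data loader
- `train_dataset = readDataset.DatasetLoader(readDataset.ds_train)`: create the training dataset object
- `train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)`: create the data loader with the configured batch size and random shuffling enabled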
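The training loop later prints `len(train_loader)` as the total number of iterations, which is the number of batches per epoch. A plain-Python sketch of what `DataLoader(..., shuffle=True)` does with sample indices, assuming the default `drop_last=False` and a hypothetical dataset of 20 images (the real size of `readDataset.ds_train` is not given here), looks like this:

```python
import math
import random


def iterate_batches(num_samples, batch_size, shuffle=True, seed=None):
    """Yield lists of sample indices, like a DataLoader with drop_last=False."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # a new permutation on each call
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]  # the last batch may be smaller


def iterations_per_epoch(num_samples, batch_size):
    # Matches len(train_loader) when drop_last=False (the DataLoader default)
    return math.ceil(num_samples / batch_size)


batches = list(iterate_batches(num_samples=20, batch_size=8, seed=0))
print([len(b) for b in batches])  # [8, 8, 4]
print(iterations_per_epoch(20, 8))  # 3
```

Because the final, smaller batch is kept, an epoch has `ceil(N / batch_size)` iterations. The real `DataLoader` also draws a fresh permutation every epoch, which corresponds to calling the generator again without a fixed seed.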
2.3 Building the Model
```python
import torch
import torchvision
from torchvision.models import ResNet50_Weights

model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
in_features = model.fc.in_features
model.fc = torch.nn.Linear(in_features, num_classes)
model.to(device)
```
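- `import torch`: import the PyTorch deep learning framework
- `import torchvision`: import PyTorch's computer vision library
- `from torchvision.models import ResNet50_Weights`: import the pretrained ResNet50 weight definitions
- `model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)`: load ResNet50 with ImageNet-pretrained weights
- `in_features = model.fc.in_features`: read the number of input features of the final fully connected layer
- `model.fc = torch.nn.Linear(in_features, num_classes)`: replace that layer with one that outputs 2 classes
- `model.to(device)`: move the model to the GPU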
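Replacing `model.fc` swaps ResNet50's 1000-class ImageNet head for a two-class one. In torchvision's ResNet50, `model.fc.in_features` is 2048, and a `torch.nn.Linear(in_features, out_features)` layer holds an `out_features × in_features` weight matrix plus `out_features` biases, so the size of the change can be checked with a little arithmetic (the helper name below is illustrative, not a torchvision function):

```python
def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameters in a fully connected layer: weight (out x in) plus optional bias (out)."""
    return in_features * out_features + (out_features if bias else 0)


RESNET50_FC_IN = 2048  # model.fc.in_features for torchvision's ResNet50

imagenet_head = linear_param_count(RESNET50_FC_IN, 1000)  # original 1000-class head
cat_dog_head = linear_param_count(RESNET50_FC_IN, 2)  # new 2-class head
print(imagenet_head, cat_dog_head)  # 2049000 4098
```

Only these 4098 parameters start from a random initialization; every other weight comes from `IMAGENET1K_V2`. Since the script later passes all of `model.parameters()` to Adam, the pretrained backbone is fine-tuned along with the new head rather than frozen.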
2.4 Training Setup
```python
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
```
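- `criterion = torch.nn.CrossEntropyLoss()`: define the cross-entropy loss, the standard choice for multi-class classification; with two output logits it also covers this binary task
- `optimizer = torch.optim.Adam(model.parameters(), lr=lr)`: define the Adam optimizer with the learning rate set above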
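`CrossEntropyLoss` takes the raw logits produced by `model.fc` (no softmax layer is needed in the model) and, by default, averages `-log(softmax(z)[y])` over the batch. Adam then scales each gradient using running estimates of its first and second moments. The plain-Python sketch below re-implements both purely for illustration, assuming PyTorch's defaults (`reduction='mean'`, `betas=(0.9, 0.999)`, `eps=1e-8`, no weight decay, no class weights or label smoothing); it is not PyTorch's own code.

```python
import math


def cross_entropy(logits_batch, labels):
    """Mean softmax cross-entropy over a batch of raw logits and integer labels."""
    total = 0.0
    for logits, y in zip(logits_batch, labels):
        m = max(logits)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[y]  # equals -log(softmax(logits)[y])
    return total / len(labels)


def adam_step(param, grad, state, lr=1e-4, betas=(0.9, 0.999), eps=1e-8):
    """One Adam update for a single scalar parameter; `state` holds t, m, v."""
    b1, b2 = betas
    state["t"] += 1
    state["m"] = b1 * state["m"] + (1 - b1) * grad  # first-moment estimate
    state["v"] = b2 * state["v"] + (1 - b2) * grad * grad  # second-moment estimate
    m_hat = state["m"] / (1 - b1 ** state["t"])  # bias correction
    v_hat = state["v"] / (1 - b2 ** state["t"])
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


print(round(cross_entropy([[0.0, 0.0]], [0]), 4))  # 0.6931, i.e. ln 2
state = {"t": 0, "m": 0.0, "v": 0.0}
print(adam_step(1.0, grad=5.0, state=state))  # about 0.9999
```

The first result is the loss of an uninformed 50/50 guess between two classes, a useful reference point for the first few printed losses. The second shows that Adam's first step moves a parameter by roughly `lr` whatever the gradient's scale, which is one reason a small learning rate such as `1e-4` suits fine-tuning pretrained weights.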
2.5 Training the Model
```python
for epoch in range(num_epochs):
    model.train()
    for iter, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print('Epoch[{}/{}],Iteration[{}/{}],Loss:{:.4f}'.format(
            epoch + 1, num_epochs, iter + 1, len(train_loader), loss.item()))
        swanlab.log({"train_loss": loss.item()})
```
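- `for epoch in range(num_epochs):`: outer loop over the training epochs
- `model.train()`: put the model in training mode, so layers such as batch normalization behave as they should during training
- `for iter, (inputs, labels) in enumerate(train_loader):`: inner loop over the batches
- `inputs, labels = inputs.to(device), labels.to(device)`: move the inputs and labels to the GPU
- `optimizer.zero_grad()`: clear the gradients left from the previous step, since PyTorch accumulates them
- `outputs = model(inputs)`: forward pass, producing the model's predictions
- `loss = criterion(outputs, labels)`: compute the loss
- `loss.backward()`: backpropagation, computing the gradients
- `optimizer.step()`: update the model parameters
- `print(...)`: print the training progress and the loss
- `swanlab.log({"train_loss": loss.item()})`: log the loss to the SwanLab experiment tracker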
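The steps of the inner loop (clear the gradients, forward pass, loss, backward pass, parameter update) do not depend on ResNet50. The sketch below mirrors them in plain Python on a hypothetical toy problem: a one-feature logistic model whose two logits are `[0, w*x + b]`, hand-written gradients in place of `loss.backward()`, plain SGD in place of Adam, and a list in place of `swanlab.log`. It prints progress in the same format as the tutorial.

```python
import math
import random

# Hypothetical toy data: (feature, label) pairs; label 1 for x > 0, label 0 for x < 0
DATA = [(-2.0, 0), (-1.0, 0), (1.0, 1), (2.0, 1)]


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def mean_loss(data, w, b):
    """Average cross-entropy of the toy model over a list of (x, y) pairs."""
    total = 0.0
    for x, y in data:
        p = sigmoid(w * x + b)  # probability of label 1
        total -= math.log(p if y == 1 else 1.0 - p)
    return total / len(data)


def train(data, num_epochs=5, batch_size=2, lr=0.5, seed=0):
    """Mirror of the tutorial's loop structure; returns (w, b, per-iteration losses)."""
    rng = random.Random(seed)
    data = list(data)  # shuffle a copy, not the caller's list
    w, b = 0.0, 0.0
    num_iters = math.ceil(len(data) / batch_size)  # plays the role of len(train_loader)
    history = []
    for epoch in range(num_epochs):
        rng.shuffle(data)  # like shuffle=True: a new order every epoch
        for i in range(num_iters):
            batch = data[i * batch_size:(i + 1) * batch_size]
            grad_w, grad_b = 0.0, 0.0  # optimizer.zero_grad()
            for x, y in batch:
                p = sigmoid(w * x + b)  # forward pass
                grad_w += (p - y) * x  # backward pass, done by hand
                grad_b += p - y
            loss = mean_loss(batch, w, b)  # criterion(outputs, labels)
            w -= lr * grad_w / len(batch)  # optimizer.step(), plain SGD here
            b -= lr * grad_b / len(batch)
            print('Epoch[{}/{}],Iteration[{}/{}],Loss:{:.4f}'.format(
                epoch + 1, num_epochs, i + 1, num_iters, loss))
            history.append(loss)  # stands in for swanlab.log({"train_loss": ...})
    return w, b, history


w, b, history = train(DATA)
print("loss on all data after training: {:.4f}".format(mean_loss(DATA, w, b)))
```

The first batch starts at ln 2 ≈ 0.6931 because both weights are zero, and the loss on the full toy dataset ends lower. Individual per-batch values can still jump around, which is normal for a loss logged once per iteration.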
3. Complete Code
```python
import swanlab
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.models import ResNet50_Weights

import readDataset  # custom dataset-reading module

# Hyperparameters and SwanLab experiment setup
num_epochs = 20
lr = 1e-4
batch_size = 8
num_classes = 2
device = "cuda"  # requires a CUDA-capable GPU; use "cpu" otherwise

swanlab.init(
    experiment_name="模型訓練實驗",  # "Model training experiment"
    description="貓狗分類",  # "Cat/dog classification"
    mode="local",
    config={"model": "resnet50", "optim": "Adam", "lr": lr, "batch_size": batch_size,
            "num_epochs": num_epochs, "num_class": num_classes, "device": device},
)

# Data
train_dataset = readDataset.DatasetLoader(readDataset.ds_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Pretrained ResNet50 with a new 2-class head
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
in_features = model.fc.in_features
model.fc = torch.nn.Linear(in_features, num_classes)
model.to(device)

# Loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Training loop
for epoch in range(num_epochs):
    model.train()
    for iter, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print('Epoch[{}/{}],Iteration[{}/{}],Loss:{:.4f}'.format(
            epoch + 1, num_epochs, iter + 1, len(train_loader), loss.item()))
        swanlab.log({"train_loss": loss.item()})
```
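4. Results
- PyCharm run log
- PyCharm terminal log
- SwanLab workspace
- Overview of the training experiment
- Charts of the training experiment
- Logs of the training experiment
- Environment of the training experiment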
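5. Problems and Solutions
Problem 1: ModuleNotFoundError: No module named 'XXX'
Solution 1: install the missing package with pip:
```shell
pip install XXX
```
When running SwanLab with `mode="local"`, install it together with its dashboard extra:
```shell
pip install 'swanlab[dashboard]'
```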
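Before running the script, it can help to see which of its imports are missing in the current environment. A minimal sketch using the standard library's `importlib.util.find_spec` follows; the module list is taken from the script's own imports, and `readDataset` is the custom module, which must be importable from the script's directory. Note that an import name and a pip package name are not always identical, so a reported name is not guaranteed to be the exact argument for `pip install`.

```python
import importlib.util

# Top-level modules imported by the training script
REQUIRED = ["swanlab", "torch", "torchvision", "readDataset"]


def missing_modules(names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```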