流程
- 定義自定義數據集類
- 定義訓練和驗證的數據增強
- 定義模型、損失函數和優化器
- 訓練循環,包括驗證
- 訓練可視化
- 整個流程
- 模型評估
- 高級功能擴展
- 混合精度訓練?
- 分布式訓練?
{:width="50%" height="50%"}
定義自定義數據集類
#======================
# 1. Custom dataset class
#======================
class CustomImageDataset(Dataset):
    """Image-folder dataset: one sub-directory per class under ``root_dir``."""

    def __init__(self, root_dir, transform=None):
        """
        :param root_dir: dataset root directory (one sub-directory per class)
        :param transform: optional augmentation / preprocessing pipeline
        """
        self.root_dir = root_dir
        self.transform = transform
        # BUGFIX: only sub-directories are classes — a stray file in root_dir
        # previously became a "class" and crashed the os.listdir() below.
        self.classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        # Collect all image paths and labels (sorted for deterministic order).
        self.image_paths = []
        self.labels = []
        for cls_name in self.classes:
            cls_dir = os.path.join(root_dir, cls_name)
            for img_name in sorted(os.listdir(cls_dir)):
                if img_name.lower().endswith(('.jpg', '.png', '.jpeg')):
                    self.image_paths.append(os.path.join(cls_dir, img_name))
                    self.labels.append(self.class_to_idx[cls_name])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image; fall back to a black placeholder on a corrupt file
        # so one bad image does not abort a whole training epoch.
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            image = Image.new('RGB', (224, 224), (0, 0, 0))

        # Apply augmentation / preprocessing if provided.
        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        return image, label
定義訓練和驗證的數據增強
#======================
# 2. Data augmentation & preprocessing
#======================
def get_transforms():
    """Build the (train, val) preprocessing pipelines.

    Training gets random augmentation; validation is deterministic.
    Both pipelines end with ImageNet mean/std normalisation.
    :return: (train_transform, val_transform)
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Richer, randomised pipeline for the training set.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    # Deterministic resize + centre-crop for validation.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    return train_transform, val_transform
定義模型、損失函數和優化器
#======================
# 3. Model definition
#======================
def create_model(num_classes, pretrained=True):
    """Create a ResNet-18 classifier with a fresh final FC layer.

    :param num_classes: number of output classes
    :param pretrained: load ImageNet weights (default True, matching the
        original behaviour; pass False for training from scratch)
    :return: the model, with ``fc`` replaced for ``num_classes`` outputs
    """
    model = resnet18(pretrained=pretrained)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    return model
訓練循環,包括驗證
#======================
# 4. Training function
#======================
def train_model(model, dataloaders, criterion, optimizer, scheduler, device,
                num_epochs=25, checkpoint_path='checkpoint.pth', resume=False):
    """Train ``model`` with per-epoch validation and checkpoint/resume support.

    :param model: network to train (already moved to ``device``)
    :param dataloaders: dict with 'train' and 'val' DataLoaders
    :param criterion: loss function
    :param optimizer: optimizer
    :param scheduler: LR scheduler, stepped once per training epoch
    :param device: torch.device to run on
    :param num_epochs: total number of epochs
    :param checkpoint_path: where the per-epoch checkpoint is written
    :param resume: if True, restore state from ``checkpoint_path`` if it exists
    :return: (model, history) — history holds loss/acc curves and best val acc
    """
    # Training history record.
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'epoch': 0, 'best_acc': 0.0,
    }

    # Optionally restore from a checkpoint.
    start_epoch = 0
    if resume and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        # BUGFIX: map_location lets a checkpoint saved on GPU load on a
        # CPU-only machine (the original torch.load would fail there).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        history = checkpoint['history']
        start_epoch = history['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")

    # Training loop.
    for epoch in range(start_epoch, num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        history['epoch'] = epoch

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # enable dropout / batch-norm updates
            else:
                model.eval()    # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over the data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients.
                optimizer.zero_grad()

                # Forward pass; gradients only tracked while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass + optimisation in the training phase only.
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics: loss weighted by batch size, plus correct count.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            # Record history.
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc.item())
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Keep a copy of the weights with the best validation accuracy.
            if phase == 'val' and epoch_acc > history['best_acc']:
                history['best_acc'] = epoch_acc.item()
                torch.save(model.state_dict(), 'best_model.pth')
                print(f"New best model saved with accuracy: {epoch_acc:.4f}")

        # Save a checkpoint at the end of every epoch (enables resume).
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'history': history,
        }
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch+1}")
        print()

    # Save the final model.
    torch.save(model.state_dict(), 'final_model.pth')
    print('Training finished!')
    return model, history
訓練可視化
#======================
# 5. Visualise training history
#======================
def plot_history(history):
    """Plot loss and accuracy curves; save to training_history.png and show."""
    plt.figure(figsize=(12, 4))

    # One subplot per metric: (history key suffix, y-label, title).
    panels = [
        ('loss', 'Loss', 'Training and Validation Loss'),
        ('acc', 'Accuracy', 'Training and Validation Accuracy'),
    ]
    for pos, (key, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, pos)
        plt.plot(history[f'train_{key}'], label=f'Train {ylabel}')
        plt.plot(history[f'val_{key}'], label=f'Validation {ylabel}')
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.legend()
        plt.title(title)

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()
整個流程
#======================
# 6. Main
#======================
def main():
    """End-to-end pipeline: data → model → train → save history → plot."""
    # Fix random seeds for reproducibility.
    torch.manual_seed(42)
    np.random.seed(42)

    # Pick the compute device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Build the augmentation pipelines.
    train_transform, val_transform = get_transforms()

    # Datasets (replace the paths with your own data).
    train_dataset = CustomImageDataset(
        root_dir='path/to/your/train_data',  # replace with your training data path
        transform=train_transform,
    )
    val_dataset = CustomImageDataset(
        root_dir='path/to/your/val_data',  # replace with your validation data path
        transform=val_transform,
    )

    # Data loaders.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False,
                            num_workers=4, pin_memory=True)
    dataloaders = {'train': train_loader, 'val': val_loader}

    # Model.
    num_classes = len(train_dataset.classes)
    model = create_model(num_classes)
    model = model.to(device)

    # Loss function, optimizer, and LR schedule.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Offer to resume if a checkpoint file already exists.
    resume_training = False
    checkpoint_path = 'checkpoint.pth'
    if os.path.exists(checkpoint_path):
        print("Checkpoint file found. Do you want to resume training? (y/n)")
        response = input().lower()
        if response == 'y':
            resume_training = True

    # Train.
    start_time = time.time()
    model, history = train_model(
        model=model,
        dataloaders=dataloaders,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        num_epochs=25,
        checkpoint_path=checkpoint_path,
        resume=resume_training,
    )
    end_time = time.time()

    # Persist the training history.
    with open('training_history.json', 'w') as f:
        json.dump(history, f, indent=4)

    # Report wall-clock training time.
    # BUGFIX: the original printed floats like "0.0h 0.0m"; cast to int.
    training_time = end_time - start_time
    hours = int(training_time // 3600)
    minutes = int((training_time % 3600) // 60)
    seconds = training_time % 60
    print(f"Total training time: {hours}h {minutes}m {seconds:.2f}s")

    # Plot the loss/accuracy curves.
    plot_history(history)


if __name__ == "__main__":
    main()
模型評估
#======================
# Model evaluation
#======================
def evaluate_model(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader``; print and return accuracy in %.

    :param model: trained network
    :param dataloader: iterable of (inputs, labels) batches
    :param device: torch.device to run on
    :return: accuracy as a percentage (0.0 for an empty dataloader)
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # BUGFIX: guard against an empty dataloader (ZeroDivisionError before).
    accuracy = 100 * correct / total if total else 0.0
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy
# NOTE(review): these statements assume `val_transform`, `model`, and `device`
# already exist in the surrounding scope, but in this file they are local to
# main() — hoist them to module level or move this snippet into a function.
test_dataset = CustomImageDataset('path/to/test_data', transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
evaluate_model(model, test_loader, device)
高級功能擴展
混合精度訓練?
# Mixed-precision (AMP) training snippet.
from torch.cuda.amp import autocast, GradScaler

# Add inside the training function:
scaler = GradScaler()

# Modified training loop: run the forward pass under autocast, then use the
# scaler for the backward pass and optimizer step.
with autocast():
    outputs = model(inputs)
    loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
分布式訓練?
# Distributed data-parallel (DDP) training snippet.
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialise the distributed process group.
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

# Wrap the model for distributed training.
model = DDP(model.to(local_rank), device_ids=[local_rank])

# Use a distributed sampler in the data loader.
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoader(..., sampler=train_sampler)