PyTorch Lightning in Depth: One Worked Example to Make Training More Efficient
PyTorch Lightning is organized around a few major building blocks, and this article walks through each of them:
- Model: the core of PyTorch Lightning is subclassing `pl.LightningModule`
- Data: data modules subclass `pl.LightningDataModule`
- Callbacks: how to build, use, and customize them
- Hooks: how they are used in the model class, the data module, and callbacks, and in what order they are called
- Logging: typically done with PyTorch Lightning's built-in TensorBoard logger
1. Defining the model class
1.1 The basic structure of a model looks like this:
```python
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self, imgsize=28, hidden=128, num_class=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.net = nn.Sequential(
            nn.Linear(imgsize * imgsize, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_class)
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.flatten(x)
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("train_loss", loss)  # log the training loss
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("val_loss", loss)  # log the validation loss
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("test_loss", loss)  # log the test loss
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```
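As a quick usage sketch (not part of the original listing; the loaders below are illustrative), this module can be trained directly with plain DataLoaders:

```python
# Hedged sketch: train the LitModel above on MNIST with plain DataLoaders
full_train = MNIST(root="data", train=True, download=True, transform=ToTensor())
train_set, val_set = random_split(full_train, [55000, 5000])

trainer = pl.Trainer(max_epochs=1, accelerator="auto")
trainer.fit(
    LitModel(),
    train_dataloaders=DataLoader(train_set, batch_size=32, shuffle=True),
    val_dataloaders=DataLoader(val_set, batch_size=32),
)
```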
1.2 The `training_step` / `validation_step` / `test_step` methods
These methods take two required arguments, `batch` and `batch_idx`, and the names must not be changed. (They accept a few other arguments as well, but those are rarely needed.)
- `batch` is whatever the dataloader returns; `batch_idx` is the index of the current batch within the epoch. For example, to log only every 100 steps you can write `if batch_idx % 100 == 0:` (see the sketch after this list).
- `training_step` must return the loss used for backpropagation; the other two methods do not have to return anything. With multiple optimizers you have to implement backpropagation manually (in recent versions).
- `training_step` and `validation_step` must be present for training; `test_step` is only called when testing the model and can be omitted during training.
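A minimal sketch of the `batch_idx` pattern mentioned above, logging only every 100 batches (the key name `train_loss_100` is made up for illustration):

```python
def training_step(self, batch, batch_idx):
    x, y = batch
    loss = self.loss_fn(self(x), y)
    if batch_idx % 100 == 0:             # log only every 100 batches
        self.log("train_loss_100", loss)
    return loss
```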
1.3 Defining optimizers and schedulers: `configure_optimizers`
This method takes no extra arguments. With multiple optimizers, return them as `return [optimizer1, optimizer2, ...], [scheduler1, scheduler2, ...]`. They can also be returned in dictionary form; here we stick with the list form (see the sketch below).
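A minimal sketch of the list-return convention (the optimizer and scheduler choices here are illustrative):

```python
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    # With several optimizers/schedulers the pattern is the same:
    # return [optimizer1, optimizer2, ...], [scheduler1, scheduler2, ...]
    return [optimizer], [scheduler]
```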
1.4 Manual backpropagation with multiple optimizers
Let's rewrite the model from 1.1 to see how to handle multiple optimizers. Here we assume there are two of them.
```python
class LitModel(pl.LightningModule):
    def __init__(self, imgsize=28, hidden=128, num_class=10):
        super().__init__()
        self.flatten = nn.Flatten()
        ####################
        # 1. Split the model into two parts so different optimizers can be applied
        self.feature_extractor = nn.Sequential(
            nn.Linear(imgsize * imgsize, hidden),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden, num_class)
        )
        # Turn off automatic optimization
        self.automatic_optimization = False
        ####################
        self.loss_fn = nn.CrossEntropyLoss()
        # Save the hyperparameters
        self.save_hyperparameters()

    def forward(self, x):
        x = self.flatten(x)
        features = self.feature_extractor(x)
        return self.classifier(features)

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        opt1, opt2 = self.optimizers()
        # 3. Manual optimization: update the feature extractor on odd batches
        if batch_idx % 2 != 0:
            opt1.zero_grad()
            self.manual_backward(loss)
            opt1.step()
        # Update the classifier every second batch
        if batch_idx % 2 == 0:
            opt2.zero_grad()
            self.manual_backward(loss)
            opt2.step()
        self.log("train_loss", loss)  # log the training loss
        # No loss is returned because backpropagation is done manually
        return None

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("val_loss", loss)  # log the validation loss
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("test_loss", loss)  # log the test loss
        return loss

    def configure_optimizers(self):
        # 2. Apply a different optimizer (and scheduler) to each parameter group
        optimizer1 = torch.optim.Adam(
            self.feature_extractor.parameters(), lr=1e-3, weight_decay=1e-4
        )
        optimizer2 = torch.optim.SGD(
            self.classifier.parameters(), lr=1e-2, momentum=0.9
        )
        # Only the second optimizer gets a learning-rate scheduler here
        scheduler2 = torch.optim.lr_scheduler.MultiStepLR(
            optimizer2, milestones=[1, 3, 5], gamma=0.5, last_epoch=-1)
        return [optimizer1, optimizer2], [scheduler2]

    # 4. Hook: manually step the scheduler after every batch
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Get the scheduler
        scheduler2 = self.lr_schedulers()
        # Step the scheduler (per batch)
        scheduler2.step()
        # Log the learning rate
        lr = scheduler2.get_last_lr()[0]
        self.log("lr", lr, prog_bar=True)

    # 5. Hook: store the hyperparameters in the checkpoint
    def on_save_checkpoint(self, checkpoint: dict) -> None:
        checkpoint["save_data"] = self.hparams
```
2. Data
```python
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    # Data preparation; runs only once
    def prepare_data(self):
        # Download the dataset
        MNIST(root="data", train=True, download=True)
        MNIST(root="data", train=False, download=True)

    # In distributed training this runs on every process
    def setup(self, stage=None):
        # Preprocessing and train/val split
        transform = ToTensor()
        mnist_full = MNIST(root="data", train=True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        self.mnist_test = MNIST(root="data", train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        # If there is no test set, simply return None here
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
```
The above is the basic structure of a data module.
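As a usage sketch (assuming the `LitModel` from section 1), the data module is passed to the Trainer through the `datamodule` argument:

```python
# Hedged usage sketch: wire the data module into a Trainer
dm = MNISTDataModule(batch_size=64)
model = LitModel()

trainer = pl.Trainer(max_epochs=3, accelerator="auto")
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
```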
3. Defining callbacks
Below is an example of a custom callback whose job is to save the config:
```python
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
from typing import Dict, Any
import os
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from omegaconf import OmegaConf


class SetupCallback(Callback):
    def __init__(self, now, cfgdir, config):
        super().__init__()
        self.now = now
        self.cfgdir = cfgdir
        self.config = config

    def on_train_start(self, trainer, pl_module):
        # Only save on the main process
        if trainer.global_rank == 0:
            print("Project config")
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

    # Also store the config inside the checkpoint
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        checkpoint["cfg_project"] = self.config
```
Here `on_train_start` and `on_save_checkpoint` are hooks defined inside the callback; a short usage sketch follows, and the next section explains how these hooks behave.
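Before moving on, here is a hedged usage sketch for the callback above (the config contents are illustrative):

```python
import datetime, os
from omegaconf import OmegaConf

now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
cfg = OmegaConf.create({"model": {"hidden": 128}})   # illustrative config

os.makedirs("logs", exist_ok=True)                   # cfgdir must exist before training starts
setup_cb = SetupCallback(now=now, cfgdir="logs", config=cfg)
trainer = pl.Trainer(max_epochs=1, callbacks=[setup_cb])
# trainer.fit(model, datamodule=dm)   # model / dm as defined in the earlier sections
```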
4. Using a few common hooks
The table below compares how three hooks are used in the model class versus in a callback, along with their call order:
| Hook | Where | Signature | Typical use | Call order | Frequency |
|---|---|---|---|---|---|
| `on_fit_start` | Model class | `def on_fit_start(self)` | Global initialization, distributed setup | Called first | Once per fit |
| | Callback | `def on_fit_start(self, trainer, pl_module)` | Prepare logging, set global state | Called second | Once per fit |
| `on_train_start` | Model class | `def on_train_start(self)` | Initialize training metrics, timers | Called first | Once per training loop |
| | Callback | `def on_train_start(self, trainer, pl_module)` | Reset callback state, pre-training preparation | Called second | Once per training loop |
| `on_save_checkpoint` | Model class | `def on_save_checkpoint(self, checkpoint)` | Save extra model state | Called second | Every checkpoint save |
| | Callback | `def on_save_checkpoint(self, trainer, pl_module, checkpoint)` | Save callback state, add metadata | Called first | Every checkpoint save |
Call-order diagram:

```text
Training starts
├── on_fit_start
│   ├── LightningModule.on_fit_start   (1)
│   └── Callback.on_fit_start          (2)
│
├── on_train_start
│   ├── LightningModule.on_train_start (3)
│   └── Callback.on_train_start        (4)
│
├── training epochs ...
│
└── checkpoint saving
    ├── Callback.on_save_checkpoint        (5)
    └── LightningModule.on_save_checkpoint (6)
```
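If you want to observe this order yourself, one simple way (a hedged sketch; the class names are made up) is to print from both sides of the same hook:

```python
class OrderDemoModule(pl.LightningModule):
    def on_fit_start(self):
        print("LightningModule.on_fit_start")

class OrderDemoCallback(Callback):
    def on_fit_start(self, trainer, pl_module):
        print("Callback.on_fit_start")

# Attach both to a Trainer and call fit(); the print order shows which one runs first.
# trainer = pl.Trainer(callbacks=[OrderDemoCallback()], fast_dev_run=True)
```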
5. Full code
simple.yaml
```yaml
trainer:
  accelerator: gpu
  devices: [1]
  max_epochs: 100

call_back:
  modelckpt:
    filename: "{epoch:03}-{train_loss:.2f}-{val_loss:.2f}"
    save_top_k: -1
    save_step: true
    every_n_epochs: 10
    verbose: true
    save_last: true

model:
  target: main_pl.LitModel
  params:
    imgsize: 28
    hidden: 128
    num_class: 10

data:
  target: main_pl.MNISTDataModule
  params:
    batch_size: 32
```
main_pl.py
```python
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
from typing import Dict, Any
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from omegaconf import OmegaConf
import argparse, os, datetime, importlib, glob
from pytorch_lightning import Trainer


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


# 1. Define the LightningModule subclass
class LitModel(pl.LightningModule):
    def __init__(self, imgsize=28, hidden=128, num_class=10):
        super().__init__()
        self.flatten = nn.Flatten()
        ####################
        # 1. Split the model into two parts so different optimizers can be applied
        self.feature_extractor = nn.Sequential(
            nn.Linear(imgsize * imgsize, hidden),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden, num_class)
        )
        # Turn off automatic optimization
        self.automatic_optimization = False
        ####################
        self.loss_fn = nn.CrossEntropyLoss()
        # Save the hyperparameters
        self.save_hyperparameters()

    def forward(self, x):
        x = self.flatten(x)
        features = self.feature_extractor(x)
        return self.classifier(features)

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        opt1, opt2 = self.optimizers()
        # 3. Manual optimization: update the feature extractor on odd batches
        if batch_idx % 2 != 0:
            opt1.zero_grad()
            self.manual_backward(loss)
            opt1.step()
        # Update the classifier every second batch
        if batch_idx % 2 == 0:
            opt2.zero_grad()
            self.manual_backward(loss)
            opt2.step()
        self.log("train_loss", loss)  # log the training loss
        # No loss is returned because backpropagation is done manually
        return None

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("val_loss", loss)  # log the validation loss
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("test_loss", loss)  # log the test loss
        return loss

    def configure_optimizers(self):
        # 2. Apply a different optimizer (and scheduler) to each parameter group
        optimizer1 = torch.optim.Adam(
            self.feature_extractor.parameters(), lr=1e-3, weight_decay=1e-4
        )
        optimizer2 = torch.optim.SGD(
            self.classifier.parameters(), lr=1e-2, momentum=0.9
        )
        # Only the second optimizer gets a learning-rate scheduler
        scheduler2 = torch.optim.lr_scheduler.MultiStepLR(
            optimizer2, milestones=[1, 3, 5], gamma=0.5, last_epoch=-1)
        return [optimizer1, optimizer2], [scheduler2]

    # 4. Hook: manually step the scheduler after every batch
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Get the scheduler
        scheduler2 = self.lr_schedulers()
        # Step the scheduler (per batch)
        scheduler2.step()
        # Log the learning rate
        lr = scheduler2.get_last_lr()[0]
        self.log("lr", lr, prog_bar=True)

    # 5. Hook: store the hyperparameters in the checkpoint
    def on_save_checkpoint(self, checkpoint: dict) -> None:
        checkpoint["save_data"] = self.hparams


# 2. Prepare the data module
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    # Data preparation; runs only once
    def prepare_data(self):
        # Download the dataset
        MNIST(root="data", train=True, download=True)
        MNIST(root="data", train=False, download=True)

    # In distributed training this runs on every process
    def setup(self, stage=None):
        # Preprocessing and train/val split
        transform = ToTensor()
        mnist_full = MNIST(root="data", train=True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        self.mnist_test = MNIST(root="data", train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        # If there is no test set, simply return None here
        return DataLoader(self.mnist_test, batch_size=self.batch_size)


class SetupCallback(Callback):
    def __init__(self, now, cfgdir, config):
        super().__init__()
        self.now = now
        self.cfgdir = cfgdir
        self.config = config

    def on_train_start(self, trainer, pl_module):
        # Only save on the main process
        if trainer.global_rank == 0:
            print("Project config")
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

    # Also store the config inside the checkpoint
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        checkpoint["cfg_project"] = self.config


if __name__ == '__main__':
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    configs_path = 'simple.yaml'
    logdir = os.path.join("logs", now)
    config = OmegaConf.load(configs_path)

    # 0. Read the basic trainer arguments (accelerator, devices, max_epochs) from the config
    trainer_kwargs = dict(config.trainer)

    # 1. Define the logger and add it to trainer_kwargs
    logger_cfg = {
        "target": "pytorch_lightning.loggers.TensorBoardLogger",
        "params": {
            "name": "tensorboard",
            "save_dir": logdir,
        }
    }
    logger_cfg = OmegaConf.create(logger_cfg)
    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

    # 2. Callbacks
    modelckpt_params = config.call_back.modelckpt
    callbacks_cfg = {
        "setup_callback": {
            "target": "main_pl.SetupCallback",
            "params": {
                "now": now,
                "cfgdir": logdir,
                "config": config,
            }
        },
        "learning_rate_logger": {
            "target": "main_pl.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
            }
        },
        "checkpoint_callback": {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": logdir,
                "filename": modelckpt_params.filename,
                "save_top_k": modelckpt_params.save_top_k,
                "verbose": modelckpt_params.verbose,
                "save_last": modelckpt_params.save_last,
                "every_n_epochs": modelckpt_params.every_n_epochs,
            }
        }
    }
    callbacks = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
    trainer_kwargs["callbacks"] = callbacks
    print(trainer_kwargs)

    # 3. Build the trainer
    trainer = Trainer(**trainer_kwargs)

    # 4. Build the data module and the model
    data = instantiate_from_config(config.data)
    model = instantiate_from_config(config.model)

    # 5. Train
    try:
        trainer.fit(model, data)
    except Exception:
        raise
```
Running it:
The `logs` directory will contain our config file, the TensorBoard logs, and the model checkpoints.
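A hedged sketch of reloading one of the saved checkpoints later (the path below is illustrative; the actual directory name is the run timestamp):

```python
# load_from_checkpoint restores the weights and the hyperparameters saved via save_hyperparameters()
model = LitModel.load_from_checkpoint("logs/<run-timestamp>/last.ckpt")
model.eval()
```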
6. Summary
PyTorch Lightning is a lightweight wrapper around PyTorch that structures your code and automates the engineering details, which noticeably speeds up deep-learning research and development. Its main advantages:
1. Structured, readable code
- Focus on research rather than engineering: model definition, training logic, and engineering code are decoupled
- Standardized interface: the `LightningModule` methods (`training_step`, `configure_optimizers`, ...) enforce a common structure
- Less boilerplate: training-loop code shrinks by roughly 80% or more
```python
# Traditional PyTorch vs Lightning
# ---------------------------
# PyTorch: the training loop is written by hand
for epoch in epochs:
    for batch in data:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()

# Lightning: you only define the logic
class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
```
2. Automated engineering details

| Feature | How to enable | Benefit |
|---|---|---|
| Distributed training | `Trainer(accelerator="gpu", devices=4)` | Multi-GPU/TPU with a single line |
| Mixed-precision training | `Trainer(precision="16-mixed")` | Saves memory and speeds up training |
| Gradient accumulation | `Trainer(accumulate_grad_batches=4)` | Simulates a larger batch size |
| Early stopping | `callbacks=[EarlyStopping(...)]` | Automatically guards against overfitting |
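As a hedged sketch, the options from the table can be combined in a single Trainer (the values here are illustrative, not taken from the article's example):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

trainer = Trainer(
    accelerator="gpu", devices=4,       # distributed training across 4 GPUs
    precision="16-mixed",               # mixed-precision training
    accumulate_grad_batches=4,          # gradient accumulation
    callbacks=[EarlyStopping(monitor="val_loss", patience=3)],  # early stopping
)
```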
3. Reproducibility and experiment management
- Version control: hyperparameters are saved automatically (`self.save_hyperparameters()`)
- Experiment tracking: built-in support for TensorBoard/W&B/MLflow
- Complete checkpoints: model, optimizer, and hyperparameter state saved in one call
```python
# Every experiment is logged automatically
trainer = Trainer(logger=TensorBoardLogger("logs/"))
```
4. Hardware agnosticism
Switching hardware is a one-line change:

```python
# CPU → GPU → TPU → multi-node distributed
trainer = Trainer(
    accelerator="auto",   # auto-detect the hardware
    devices="auto",       # use all available devices
    strategy="ddp_find_unused_parameters_true"  # distributed strategy
)
```
5. Debugging and development speed

```python
# Quickly sanity-check the code
trainer = Trainer(fast_dev_run=True)    # runs only a single batch

# Performance profiling
trainer = Trainer(profiler="advanced")  # identify bottlenecks
```