pytorch lightning最簡上手
pytorch lightning 是對原生 pytorch 的通用模型開發過程進行封裝的一個工具庫。本文不會介紹它的高級功能,而是通過幾個最簡單的例子來幫助讀者快速理解、上手基本的使用方式。在掌握基礎 API 和使用方式之后,讀者可自行到 pytorch lightning 的官方文檔,了解進階 API。本文假設讀者對原生 pytorch 訓練腳本的搭建方法已經比較熟悉。
安裝
pytorch lighning 的安裝非常簡單,直接使用 pip 安裝即可:
pip install pytorch-lightning
最簡例子
pytorch lightning 有兩個最核心的 API:LigtningModule 和 Trainer 。
其中 LightningModule 是我們熟悉的 torch.nn.Module 的子類,可以通過
print(isinstance(pl.LightningModule(), torch.nn.Module))
來驗證。這意味著該類同樣需要實現 forward 方法,并可直接通過實例調用。
Trainer 則是開始執行模型訓練、測試過程的類,傳入一個 LightningModule 和對應控制參數來實例化即可開始訓練。
我們從一個最簡單的例子——MNIST 手寫數字識別開始:
1 導入必要的庫
導入 pytorch_lightning 和 pytorch 常用的庫。
import osimport torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
2 實現最簡LigntningModule
我們先實現一個最簡的 LightningModule。
-
__init__
構造函數中,像常見的 torch.nn.Module 一樣,我們定義好模型的層。由于是最簡實例,這里只有一層線性層,將手寫數字圖像映射為輸出 logits。
-
forward
由于是繼承自 torch.nn.Module,因此實現 forward 方法是必須的。forward 方法要完成模型的前向過程,這里直接調用 __init__ 中定義好的線性層,完成模型前向過程。
-
train_dataloader
train_dataloader 方法也是最簡實現中必須的,它的功能是獲取訓練集的 DataLoader。這里我們返回 MNIST 數據集的 DataLoader。dataloader 的獲取也可以不在類內實現,而是在 fit 時傳入,后面會介紹。
-
training_step
training_step 是是 LigtningModule 的核心方法,它定義了一個訓練步中需要做的事情。在深度學習的訓練步中,最核心的事情就是模型前向,得到結果,計算損失,反向傳播,更新參數,這幾步在 pytorch 中都有對應的方法供調用。但是在 pytorch lightning 中,我們只需要進行模型前向,并返回必要的信息即可。在最簡實現中,我們只需返回損失。
-
configure_optimizer
在 training_step 中,我們只需返回損失,這意味著模型的反向傳播和參數更新過程由 pytorch lightning 幫我們完成了。雖然這個過程可以有框架自己完成,但是我們還是要指定參數更新所用的優化器,在很多模型中,優化器、學習率等超參數設置對結果影響很大。在最簡實現中,我們設置好學習率,并返回一個 Adam 優化器。
class MNISTModel(pl.LightningModule):def __init__(self):super(MNISTModel, self).__init__()self.l1 = torch.nn.Linear(28 * 28, 10)def forward(self, x):return torch.relu(self.l1(x.view(x.size(0), -1)))def train_dataloader(self):return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)def training_step(self, batch, batch_nb):x, y = batchloss = F.cross_entropy(self(x), y)return lossdef configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=0.02)
以上我們實現 training_step,train_dataloader, configure_optimizer,已經是最簡單的 LightningModule 的實現了。如果連這三個方法都沒有實現的話,將會報錯:
No `xxx` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined
3 開始訓練
在實現好 LightningModule 之后,就可以開始訓練了。
啟動訓練的最簡實現非常簡單,只需三行:實例化模型、實例化訓練器、開始訓練!
model = MNISTModel()
trainer = pl.Trainer(gpus=1, max_epochs=2)
trainer.fit(model)
開始訓練后,pytorch lightning 會打印出可用設備、模型參數等豐富的信息。
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]| Name | Type | Params
--------------------------------
0 | l1 | Linear | 7.9 K
--------------------------------
7.9 K Trainable params
0 Non-trainable params
7.9 K Total params
0.031 Total estimated model params size (MB)
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:07<00:00, 261.53it/s, loss=1.3, v_num=10]
總結
以上我們用 30 行左右代碼,實現了一個最簡的 pytorch lightning 訓練過程。這足以體現出 pytorch lightning 的簡潔、易用。但是,顯然這個最簡實現缺少了很多東西,比如驗證、測試、日志打印、模型保存等。接下來,我們將實現相對完整但依舊簡潔的 pytorch lightning 模型開發過程。
pytorch lightning更多功能
本節將介紹相對更完整的 pytorch lightning 模型開發過程。
LighningModeul需實現方法
在一個相對完整的 LightnintModule 中,用戶應當實現以下方法:
1 模型定義 (__init__)
通常定義模型的各個層,在 forward 調用這些層,完成模型前向。與原生 pytorch 類似。
2 前向計算 (forward)
與 torch.nn.Module 的 forward 中做的事情一樣,調用 _init_ 中定義的層。完成模型前向。與原生 pytorch 類似。
3 訓練/驗證/測試步 (training_step/validation_step/test_step)
定義訓練/測試/訓練每一步中要做的事情,一般是計算損失、指標并返回。
def training_step(self, batch, batch_idx):# ....return xxx # 如果是training_step, 則必須包含損失
通常有兩個入參 batch 和 batch_idx。是 batch 是 dataloader 給出的輸入數據和標簽,batch_idx 是當前 batch 的索引。
注意訓練步的返回值必須是損失值,或者是包含 ‘loss’ 字段的字典。驗證/測試步的返回值不必包括損失,可以是任意結果。
4 訓練/驗證/測試步結束后 (training_step_end/validation_step_end/test_step_end)
只在使用多個node進行訓練且結果涉及如softmax之類需要全部輸出聯合運算的步驟時使用該函數。
5 訓練/驗證/測試輪結束后 (training_epoch_end/validation_epoch_end/test_epoch_end)
以 training_epoch_end 為例,其他類似。
如果需要對整一輪的結果進行處理,比如計算一些平均指標等,可以通過 training_epoch_end 來實現。
def training_epoch_end(self, outputs):# ....return xxx
其中入參 outputs 是一個列表,包含了每一步 training_step 返回的內容。我們可以在每一輪結束后,對每一步的結果進行處理。
4 選用優化器 (configure_optimizers)
設置模型參數更新所用的優化器。值得一提的是如果需要多個優化器(比如在訓練 GAN 時),可以返回優化器列表。也可以在優化器的基礎上返回學習率調整器,那就要返回兩個列表。
5 數據加載器 (train_dataloader, val_dataloader, test_dataloader)
返回 dataloader。
各個 dataloader 也可以在運行 fit/validation/test 時傳入,如:
train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
model = MNISTModel() # 不需要實現get_dataloader方法
trainer.fit(model, train_loader)
LightningModule自帶工具
LightningModule 中提供了一些常用工具供用戶直接使用:
log
Tensorboard 損失/指標日志保存和查看,不要自己定義,直接用即可。用法非常簡單,將要記錄的值傳入:
self.log('train loss', loss)
當然一個功能完整的日志保存接口肯定提供了很多參數來控制,比如是按照 epoch 記錄還是按照 step 記錄、多卡訓練時如何同步、指標是否要展示在進度條上、指標是否要保存在日志文件中等等。pytorch lightning 為這些選項都提供了控制參數,讀者可以參考官方文檔中 log 相關部分。
python 自帶的 print 函數在進行多進程訓練時會在每個進程都打印內容,這是原生 pytorch 進行分布式訓練時一個很小但是很頭疼的問題。LightningModule 提供的 print 只打印一次。
freeze
凍結所有權重以供預測時候使用。僅當已經訓練完成且后面只測試時使用。
Trainer實例化參數
在實例化 Trainer 時,pytorch lightning 也提供了很多控制參數,這里介紹常用的幾個,完整參數及含義請參考官方文檔中 Trainer 相關部分。
- default_root_dir:默認存儲地址。所有的實驗變量和權重全部會被存到這個文件夾里面。默認情況下,實驗結果會存在
lightning_logs/version_x/
。 - max_epochs:最大訓練周期數,默認為 1000,如果不設上限 epoch 數,設置為 -1。
- auto_scale_batch_size:在進行訓練前自動選擇合適的batch size。
- auto_select_gpus:自動選擇合適的GPU。尤其是在有GPU處于獨占模式時候,非常有用。
- gpus:控制使用的GPU數。當設定為None時,使用 cpu。
- auto_lr_find:自動找到合適的初始學習率。使用了該論文的技術。當且僅當執行
trainer.tune(model)
代碼時工作。 - precision:浮點數精度。默認 32,即常規單精度 fp32 旬來呢。指定為 16 可以使用 fp16 精度加快模型訓練并減少顯存占用。
- val_check_interval:進行驗證的周期。默認為 1,如果要訓練 10 個 epoch 進行一次驗證,設置為 10。
- fast_dev_run:如果設定為true,會只執行一個 batch 的 train, val 和 test,然后結束。僅用于debug。
- callbacks:需要調用的 callback 函數列表,關于常用 callback 函數下面會介紹。
- …
callback函數
Callback 是一個自包含的程序,可以與訓練流程交織在一起,而不會污染主要的研究邏輯。Callback 并不一定只能在 epoch 結尾調用。pytorch-lightning 提供了數十個hook(接口,調用位置)可供選擇,也可以自定義callback,實現任何想實現的模塊。
推薦使用方式是,隨問題和項目變化的操作,實現到 lightning module里面。而獨立的、可復用的內容則可以定義單獨的模塊,方便多個模型調用。
常見的內建 callback 如:EarlyStopping,根據某個值,在數個epoch沒有提升的情況下提前停止訓練。。PrintTableMetricsCallback,在每個epoch結束后打印一份結果整理表格等。更多內建 callbacks 可參考相關文檔。
模型加載與保存
模型保存
ModelCheckpoint 是一個自動儲存的 callback 模塊。默認情況下訓練過程中只會自動儲存最新的模型與相關參數,而用戶可以通過這個 module 自定義。如觀測一個 val_loss
的值,并儲存 top 3 好的模型,且同時儲存最后一個 epoch 的模型,等等。例:
from pytorch_lightning.callbacks import ModelCheckpoint# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(monitor='val_loss',filename='sample-mnist-{epoch:02d}-{val_loss:.2f}',save_top_k=3,mode='min',save_last=True
)trainer = pl.Trainer(gpus=1, max_epochs=3, callbacks=[checkpoint_callback])
ModelCheckpoint Callback中,如果 save_weights_only=True
,那么將會只儲存模型的權重,相當于 model.save_weights(filepath)
,反之會儲存整個模型(包括模型結構),相當于model.save(filepath)
)。
另外,也可以手動存儲checkpoint: trainer.save_checkpoint("example.ckpt")
模型加載
加載一個模型,包括它的模型權重和超參數:
model = MyLightingModule.load_from_checkpoint(PATH)print(model.learning_rate)
# 打印出超參數model.eval()
y_hat = model(x)
加載模型時替換一些超參數:
class LitModel(LightningModule):def __init__(self, in_dim, out_dim):super().__init__()self.save_hyperparameters()self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)# 如果在訓練和保存模型時,超參數設置如下,在加載后可以替換這些超參數。
LitModel(in_dim=32, out_dim=10)# 仍然使用in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)# 替換為in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
完整加載訓練狀態,包括模型的一切,以及和訓練相關的一切參數,如 model, epoch, step, LR schedulers, apex 等。
model = LitModel()
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')# 自動恢復 model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)
實例
基于第三節介紹的更多功能,我們擴展第二節 MNIST 訓練程序。代碼如下。
import osimport torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
import numpy as npclass MNISTModel(pl.LightningModule):def __init__(self):super().__init__()self.fc = nn.Linear(28 * 28, 10)def forward(self, x):return torch.relu(self.fc(x.view(-1, 28 * 28)))def training_step(self, batch, batch_nb):# REQUIREDx, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log('train_loss', loss, on_step=False, on_epoch=True)return {'loss': loss}def validation_step(self, batch, batch_nb):# OPTIONALx, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)pred = y_hat.argmax(dim=1, keepdim=True)correct = pred.eq(y.view_as(pred)).sum().item()acc = correct / x.shape[0]self.log('val_acc', acc, on_step=False, on_epoch=True)self.log('val_loss', loss, on_step=False, on_epoch=True)return {'val_loss': loss, 'val_acc': acc}def validation_epoch_end(self, outputs):# OPTIONALavg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()avg_acc = np.mean([x['val_acc'] for x in outputs])return {'val_loss': avg_loss, 'val_acc': avg_acc}def test_step(self, batch, batch_nb):# OPTIONALx, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)return {'test_loss': loss}def test_epoch_end(self, outputs):# OPTIONALavg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()return {'test_loss': avg_loss}def configure_optimizers(self):# REQUIREDreturn torch.optim.Adam(self.parameters(), lr=0.02)def train_dataloader(self):# REQUIREDreturn DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)def val_dataloader(self):# OPTIONALreturn DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)def test_dataloader(self):# OPTIONALreturn DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)model = MNISTModel()
trainer = pl.Trainer(gpus=1,max_epochs=10,callbacks=[pl.callbacks.EarlyStopping( monitor="val_loss", patience=3),]
)
trainer.fit(model)
trainer.test()
Ref
- pytorch lightning 的官方文檔
- Pytorch Lightning 完全攻略
- 參考代碼