pytorch lightning最簡上手

pytorch lightning最簡上手

pytorch lightning 是對原生 pytorch 的通用模型開發過程進行封裝的一個工具庫。本文不會介紹它的高級功能,而是通過幾個最簡單的例子來幫助讀者快速理解、上手基本的使用方式。在掌握基礎 API 和使用方式之后,讀者可自行到 pytorch lightning 的官方文檔,了解進階 API。本文假設讀者對原生 pytorch 訓練腳本的搭建方法已經比較熟悉。

安裝

pytorch lighning 的安裝非常簡單,直接使用 pip 安裝即可:

pip install pytorch-lightning

最簡例子

pytorch lightning 有兩個最核心的 API:LigtningModuleTrainer

其中 LightningModule 是我們熟悉的 torch.nn.Module 的子類,可以通過

print(isinstance(pl.LightningModule(), torch.nn.Module))

來驗證。這意味著該類同樣需要實現 forward 方法,并可直接通過實例調用。

Trainer 則是開始執行模型訓練、測試過程的類,傳入一個 LightningModule 和對應控制參數來實例化即可開始訓練。

我們從一個最簡單的例子——MNIST 手寫數字識別開始:

1 導入必要的庫

導入 pytorch_lightning 和 pytorch 常用的庫。

import osimport torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

2 實現最簡LigntningModule

我們先實現一個最簡的 LightningModule。

  • __init__

    構造函數中,像常見的 torch.nn.Module 一樣,我們定義好模型的層。由于是最簡實例,這里只有一層線性層,將手寫數字圖像映射為輸出 logits。

  • forward

    由于是繼承自 torch.nn.Module,因此實現 forward 方法是必須的。forward 方法要完成模型的前向過程,這里直接調用 __init__ 中定義好的線性層,完成模型前向過程。

  • train_dataloader

    train_dataloader 方法也是最簡實現中必須的,它的功能是獲取訓練集的 DataLoader。這里我們返回 MNIST 數據集的 DataLoader。dataloader 的獲取也可以不在類內實現,而是在 fit 時傳入,后面會介紹。

  • training_step

    training_step 是是 LigtningModule 的核心方法,它定義了一個訓練步中需要做的事情。在深度學習的訓練步中,最核心的事情就是模型前向,得到結果,計算損失,反向傳播,更新參數,這幾步在 pytorch 中都有對應的方法供調用。但是在 pytorch lightning 中,我們只需要進行模型前向,并返回必要的信息即可。在最簡實現中,我們只需返回損失。

  • configure_optimizer

    在 training_step 中,我們只需返回損失,這意味著模型的反向傳播和參數更新過程由 pytorch lightning 幫我們完成了。雖然這個過程可以有框架自己完成,但是我們還是要指定參數更新所用的優化器,在很多模型中,優化器、學習率等超參數設置對結果影響很大。在最簡實現中,我們設置好學習率,并返回一個 Adam 優化器。

class MNISTModel(pl.LightningModule):def __init__(self):super(MNISTModel, self).__init__()self.l1 = torch.nn.Linear(28 * 28, 10)def forward(self, x):return torch.relu(self.l1(x.view(x.size(0), -1)))def train_dataloader(self):return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)def training_step(self, batch, batch_nb):x, y = batchloss = F.cross_entropy(self(x), y)return lossdef configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=0.02)

以上我們實現 training_step,train_dataloader, configure_optimizer,已經是最簡單的 LightningModule 的實現了。如果連這三個方法都沒有實現的話,將會報錯:

 No `xxx` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined

3 開始訓練

在實現好 LightningModule 之后,就可以開始訓練了。

啟動訓練的最簡實現非常簡單,只需三行:實例化模型、實例化訓練器、開始訓練!

model = MNISTModel()
trainer = pl.Trainer(gpus=1, max_epochs=2)
trainer.fit(model)

開始訓練后,pytorch lightning 會打印出可用設備、模型參數等豐富的信息。

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]| Name | Type   | Params
--------------------------------
0 | l1   | Linear | 7.9 K
--------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:07<00:00, 261.53it/s, loss=1.3, v_num=10]

總結

以上我們用 30 行左右代碼,實現了一個最簡的 pytorch lightning 訓練過程。這足以體現出 pytorch lightning 的簡潔、易用。但是,顯然這個最簡實現缺少了很多東西,比如驗證、測試、日志打印、模型保存等。接下來,我們將實現相對完整但依舊簡潔的 pytorch lightning 模型開發過程。

pytorch lightning更多功能

本節將介紹相對更完整的 pytorch lightning 模型開發過程。

LighningModeul需實現方法

在一個相對完整的 LightnintModule 中,用戶應當實現以下方法:

1 模型定義 (__init__)

通常定義模型的各個層,在 forward 調用這些層,完成模型前向。與原生 pytorch 類似。

2 前向計算 (forward)

與 torch.nn.Module 的 forward 中做的事情一樣,調用 _init_ 中定義的層。完成模型前向。與原生 pytorch 類似。

3 訓練/驗證/測試步 (training_step/validation_step/test_step)

定義訓練/測試/訓練每一步中要做的事情,一般是計算損失、指標并返回。

def training_step(self, batch, batch_idx):# ....return xxx # 如果是training_step, 則必須包含損失

通常有兩個入參 batch 和 batch_idx。是 batch 是 dataloader 給出的輸入數據和標簽,batch_idx 是當前 batch 的索引。

注意訓練步的返回值必須是損失值,或者是包含 ‘loss’ 字段的字典。驗證/測試步的返回值不必包括損失,可以是任意結果。

4 訓練/驗證/測試步結束后 (training_step_end/validation_step_end/test_step_end)

只在使用多個node進行訓練且結果涉及如softmax之類需要全部輸出聯合運算的步驟時使用該函數。

5 訓練/驗證/測試輪結束后 (training_epoch_end/validation_epoch_end/test_epoch_end)

以 training_epoch_end 為例,其他類似。

如果需要對整一輪的結果進行處理,比如計算一些平均指標等,可以通過 training_epoch_end 來實現。

def training_epoch_end(self, outputs):# ....return xxx

其中入參 outputs 是一個列表,包含了每一步 training_step 返回的內容。我們可以在每一輪結束后,對每一步的結果進行處理。

4 選用優化器 (configure_optimizers)

設置模型參數更新所用的優化器。值得一提的是如果需要多個優化器(比如在訓練 GAN 時),可以返回優化器列表。也可以在優化器的基礎上返回學習率調整器,那就要返回兩個列表。

5 數據加載器 (train_dataloader, val_dataloader, test_dataloader)

返回 dataloader。

各個 dataloader 也可以在運行 fit/validation/test 時傳入,如:

train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
model = MNISTModel()		# 不需要實現get_dataloader方法
trainer.fit(model, train_loader)

LightningModule自帶工具

LightningModule 中提供了一些常用工具供用戶直接使用:

log

Tensorboard 損失/指標日志保存和查看,不要自己定義,直接用即可。用法非常簡單,將要記錄的值傳入:

self.log('train loss', loss)

當然一個功能完整的日志保存接口肯定提供了很多參數來控制,比如是按照 epoch 記錄還是按照 step 記錄、多卡訓練時如何同步、指標是否要展示在進度條上、指標是否要保存在日志文件中等等。pytorch lightning 為這些選項都提供了控制參數,讀者可以參考官方文檔中 log 相關部分。

print

python 自帶的 print 函數在進行多進程訓練時會在每個進程都打印內容,這是原生 pytorch 進行分布式訓練時一個很小但是很頭疼的問題。LightningModule 提供的 print 只打印一次。

freeze

凍結所有權重以供預測時候使用。僅當已經訓練完成且后面只測試時使用。

Trainer實例化參數

在實例化 Trainer 時,pytorch lightning 也提供了很多控制參數,這里介紹常用的幾個,完整參數及含義請參考官方文檔中 Trainer 相關部分。

  • default_root_dir:默認存儲地址。所有的實驗變量和權重全部會被存到這個文件夾里面。默認情況下,實驗結果會存在 lightning_logs/version_x/
  • max_epochs:最大訓練周期數,默認為 1000,如果不設上限 epoch 數,設置為 -1。
  • auto_scale_batch_size:在進行訓練前自動選擇合適的batch size。
  • auto_select_gpus:自動選擇合適的GPU。尤其是在有GPU處于獨占模式時候,非常有用。
  • gpus:控制使用的GPU數。當設定為None時,使用 cpu。
  • auto_lr_find:自動找到合適的初始學習率。使用了該論文的技術。當且僅當執行 trainer.tune(model) 代碼時工作。
  • precision:浮點數精度。默認 32,即常規單精度 fp32 旬來呢。指定為 16 可以使用 fp16 精度加快模型訓練并減少顯存占用。
  • val_check_interval:進行驗證的周期。默認為 1,如果要訓練 10 個 epoch 進行一次驗證,設置為 10。
  • fast_dev_run:如果設定為true,會只執行一個 batch 的 train, val 和 test,然后結束。僅用于debug。
  • callbacks:需要調用的 callback 函數列表,關于常用 callback 函數下面會介紹。

callback函數

Callback 是一個自包含的程序,可以與訓練流程交織在一起,而不會污染主要的研究邏輯。Callback 并不一定只能在 epoch 結尾調用。pytorch-lightning 提供了數十個hook(接口,調用位置)可供選擇,也可以自定義callback,實現任何想實現的模塊。

推薦使用方式是,隨問題和項目變化的操作,實現到 lightning module里面。而獨立的、可復用的內容則可以定義單獨的模塊,方便多個模型調用。

常見的內建 callback 如:EarlyStopping,根據某個值,在數個epoch沒有提升的情況下提前停止訓練。。PrintTableMetricsCallback,在每個epoch結束后打印一份結果整理表格等。更多內建 callbacks 可參考相關文檔。

模型加載與保存

模型保存

ModelCheckpoint 是一個自動儲存的 callback 模塊。默認情況下訓練過程中只會自動儲存最新的模型與相關參數,而用戶可以通過這個 module 自定義。如觀測一個 val_loss 的值,并儲存 top 3 好的模型,且同時儲存最后一個 epoch 的模型,等等。例:

from pytorch_lightning.callbacks import ModelCheckpoint# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(monitor='val_loss',filename='sample-mnist-{epoch:02d}-{val_loss:.2f}',save_top_k=3,mode='min',save_last=True
)trainer = pl.Trainer(gpus=1, max_epochs=3, callbacks=[checkpoint_callback])

ModelCheckpoint Callback中,如果 save_weights_only=True,那么將會只儲存模型的權重,相當于 model.save_weights(filepath),反之會儲存整個模型(包括模型結構),相當于model.save(filepath))。

另外,也可以手動存儲checkpoint: trainer.save_checkpoint("example.ckpt")

模型加載

加載一個模型,包括它的模型權重和超參數:

model = MyLightingModule.load_from_checkpoint(PATH)print(model.learning_rate)
# 打印出超參數model.eval()
y_hat = model(x)

加載模型時替換一些超參數:

class LitModel(LightningModule):def __init__(self, in_dim, out_dim):super().__init__()self.save_hyperparameters()self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)# 如果在訓練和保存模型時,超參數設置如下,在加載后可以替換這些超參數。
LitModel(in_dim=32, out_dim=10)# 仍然使用in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)# 替換為in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

完整加載訓練狀態,包括模型的一切,以及和訓練相關的一切參數,如 model, epoch, step, LR schedulers, apex 等。

model = LitModel()
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')# 自動恢復 model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)

實例

基于第三節介紹的更多功能,我們擴展第二節 MNIST 訓練程序。代碼如下。

import osimport torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
import numpy as npclass MNISTModel(pl.LightningModule):def __init__(self):super().__init__()self.fc = nn.Linear(28 * 28, 10)def forward(self, x):return torch.relu(self.fc(x.view(-1, 28 * 28)))def training_step(self, batch, batch_nb):# REQUIREDx, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log('train_loss', loss, on_step=False, on_epoch=True)return {'loss': loss}def validation_step(self, batch, batch_nb):# OPTIONALx, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)pred = y_hat.argmax(dim=1, keepdim=True)correct = pred.eq(y.view_as(pred)).sum().item()acc = correct / x.shape[0]self.log('val_acc', acc, on_step=False, on_epoch=True)self.log('val_loss', loss, on_step=False, on_epoch=True)return {'val_loss': loss, 'val_acc': acc}def validation_epoch_end(self, outputs):# OPTIONALavg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()avg_acc = np.mean([x['val_acc'] for x in outputs])return {'val_loss': avg_loss, 'val_acc': avg_acc}def test_step(self, batch, batch_nb):# OPTIONALx, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)return {'test_loss': loss}def test_epoch_end(self, outputs):# OPTIONALavg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()return {'test_loss': avg_loss}def configure_optimizers(self):# REQUIREDreturn torch.optim.Adam(self.parameters(), lr=0.02)def train_dataloader(self):# REQUIREDreturn DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)def val_dataloader(self):# OPTIONALreturn DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)def test_dataloader(self):# OPTIONALreturn DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)model = MNISTModel()
trainer = pl.Trainer(gpus=1,max_epochs=10,callbacks=[pl.callbacks.EarlyStopping( monitor="val_loss", patience=3),]
)
trainer.fit(model)
trainer.test()

Ref

  • pytorch lightning 的官方文檔
  • Pytorch Lightning 完全攻略
  • 參考代碼

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532367.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532367.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532367.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

RT-Smart 官方 ARM 32 平臺 musl gcc 工具鏈下載

前言 RT-Smart 的開發離不開 musl gcc 工具鏈&#xff0c;用于編譯 RT-Smart 內核與用戶態應用程序 RT-Smart musl gcc 工具鏈代碼當前未開源&#xff0c;但可以下載到 RT-Thread 官方編譯好的最新的 musl gcc 工具鏈 ARM 32位 平臺 比如 RT-Smart 最好用的 ARM32 位 qemu 平…

java list翻轉_JAVA實現兩種方法反轉單列表

/***authorluochengcheng* 定義一個單鏈表*/classNode {//變量private intrecord;//指向下一個對象privateNode nextNode;public Node(intrecord) {super();this.record record;}public intgetRecord() {returnrecord;}public void setRecord(intrecord) {this.record record;}…

OpenAI Whisper論文筆記

OpenAI Whisper論文筆記 OpenAI 收集了 68 萬小時的有標簽的語音數據&#xff0c;通過多任務、多語言的方式訓練了一個 seq2seq &#xff08;語音到文本&#xff09;的 Transformer 模型&#xff0c;自動語音識別&#xff08;ASR&#xff09;能力達到商用水準。本文為李沐老師…

mysql 工具 08s01_Mysql管理必備工具Maatkit詳解之十四(mk-kill)

mk-kill - 顧名思義&#xff0c;殺mysql線程。安裝方法查看這里。在一個OLTP的生產環境&#xff0c;一般不會讓sql執行過長的時間&#xff0c;特別是myisam這樣表鎖的引擎&#xff0c;如果出現長時間執行的sql一般是誤操作&#xff0c;要不就是出現問題了。出現這種情況&#x…

【經典簡讀】知識蒸餾(Knowledge Distillation) 經典之作

【經典簡讀】知識蒸餾(Knowledge Distillation) 經典之作 轉自&#xff1a;【經典簡讀】知識蒸餾(Knowledge Distillation) 經典之作 作者&#xff1a;潘小小 知識蒸餾是一種模型壓縮方法&#xff0c;是一種基于“教師-學生網絡思想”的訓練方法&#xff0c;由于其簡單&#xf…

深度學習三大謎團:集成、知識蒸餾和自蒸餾

深度學習三大謎團&#xff1a;集成、知識蒸餾和自蒸餾 轉自&#xff1a;https://mp.weixin.qq.com/s/DdgjJ-j6jHHleGtq8DlNSA 原文&#xff08;英&#xff09;&#xff1a;https://www.microsoft.com/en-us/research/blog/three-mysteries-in-deep-learning-ensemble-knowledge…

在墻上找垂直線_墻上如何快速找水平線

在裝修房子的時候&#xff0c;墻面的面積一般都很大&#xff0c;所以在施工的時候要找準水平線很重要&#xff0c;那么一般施工人員是如何在墻上快速找水平線的呢&#xff1f;今天小編就來告訴大家幾種找水平線的方法。一、如何快速找水平線1、用一根透明的軟管&#xff0c;長度…

百度地圖mysql打點_關于百度地圖連接MYSQL的問題,謝謝啦!

該樓層疑似違規已被系統折疊 隱藏此樓查看此樓大家好&#xff0c;剛使用百度地圖API&#xff0c;請教大家一個問題&#xff0c;謝啦&#xff01;我需要從我的數據庫中取出字段為"city"的所有數據&#xff0c;然后通過bdGEO()函數在地圖上標注這些城市&#xff0c;我是…

PyTorch中的torch.nn.Parameter() 詳解

PyTorch中的torch.nn.Parameter() 詳解 今天來聊一下PyTorch中的torch.nn.Parameter()這個函數&#xff0c;筆者第一次見的時候也是大概能理解函數的用途&#xff0c;但是具體實現原理細節也是云里霧里&#xff0c;在參考了幾篇博文&#xff0c;做過幾個實驗之后算是清晰了&am…

Vision Transformer(ViT)PyTorch代碼全解析(附圖解)

Vision Transformer&#xff08;ViT&#xff09;PyTorch代碼全解析 最近CV領域的Vision Transformer將在NLP領域的Transormer結果借鑒過來&#xff0c;屠殺了各大CV榜單。本文將根據最原始的Vision Transformer論文&#xff0c;及其PyTorch實現&#xff0c;將整個ViT的代碼做一…

hdfs的副本數為啥增加了_HDFS詳解之塊大小和副本數

1.HDFSHDFS : 偽分布式(學習)NNDNSNNsbin/start-dfs.sh(開啟hdfs使用的腳本)bin/hdfs dfs -ls (輸入命令加前綴bin/hdfs dfs)2.block(塊)dfs.blocksize &#xff1a; 134217728(字節) / 128M 官網默認一個塊的大小128M*舉例理解塊1個文件 130M&#xff0c;默認一個塊的大小128M…

Linux下的ELF文件、鏈接、加載與庫(含大量圖文解析及例程)

Linux下的ELF文件、鏈接、加載與庫 鏈接是將將各種代碼和數據片段收集并組合為一個單一文件的過程&#xff0c;這個文件可以被加載到內存并執行。鏈接可以執行與編譯時&#xff0c;也就是在源代碼被翻譯成機器代碼時&#xff1b;也可以執行于加載時&#xff0c;也就是被加載器加…

mysql gender_Mysql第一彈

1、創建數據庫pythoncreate database python charsetutf8;2、設計班級表結構為id、name、isdelete&#xff0c;編寫創建表的語句create table classes(id int unsigned auto_increment primary key not null,name varchar(10),isdelete bit default 0);向班級表中插入數據pytho…

python virtualenv nginx_Ubuntu下搭建Nginx+supervisor+pypy+virtualenv

系統&#xff1a;Ubuntu 14.04 LTS搭建python的運行環境&#xff1a;NginxSupervisorPypyVirtualenv軟件說明&#xff1a;Nginx&#xff1a;通過upstream進行負載均衡Supervisor&#xff1a;管理python進程Pypy&#xff1a;用Python實現的Python解釋器PyPy is a fast, complian…

如何設置mysql表中文亂碼_php mysql表中文亂碼問題如何解決

為避免mysql中出現中文亂碼&#xff0c;建議在創建數據庫時指定編碼格式&#xff1a;復制代碼 代碼示例:create database zzjz CHARACTER SET gbk COLLATE gbk_chinese_ci;create table zz_employees (employeeid int unsigned not null auto_increment primary key,name varch…

java 按鈕 監聽_Button的四種監聽方式

Button按鈕設置點擊的四種監聽方式注&#xff1a;加粗放大的都是改變的代碼1.使用匿名內部類的形式進行設置使用匿名內部類的形式&#xff0c;直接將需要設置的onClickListener接口對象初始化&#xff0c;內部的onClick方法會在按鈕被點擊的時候執行第一個活動的java代碼&#…

java int轉bitmap_Java Base64位編碼與String字符串的相互轉換,Base64與Bitmap的相互轉換實例代碼...

首先是網上大神給的類package com.duanlian.daimengmusic.utils;public final class Base64Util {private static final int BASELENGTH 128;private static final int LOOKUPLENGTH 64;private static final int TWENTYFOURBITGROUP 24;private static final int EIGHTBIT …

linux查看java虛擬機內存_深入理解java虛擬機(linux與jvm內存關系)

本文轉載自美團技術團隊發表的同名文章https://tech.meituan.com/linux-jvm-memory.html一, linux與進程內存模型要理解jvm最重要的一點是要知道jvm只是linux的一個進程,把jvm的視野放大,就能很好的理解JVM細分的一些概念下圖給出了硬件系統進程三個層面內存之間的關系.從硬件上…

java 循環stringbuffer_java常用類-----StringBuilder和StringBuffer的用法

一、可變字符常用方法package cn.zxg.PackgeUse;/*** 測試StringBuilder,StringBuffer可變字符序列常用方法*/public class TestStringBuilder2 {public static void main(String[] args) {StringBuilder sbnew StringBuilder();for(int i0;i<26;i){char temp(char)(ai);sb.…

java function void_Java8中你可能不知道的一些地方之函數式接口實戰

什么時候可以使用 Lambda&#xff1f;通常 Lambda 表達式是用在函數式接口上使用的。從 Java8 開始引入了函數式接口&#xff0c;其說明比較簡單&#xff1a;函數式接口(Functional Interface)就是一個有且僅有一個抽象方法&#xff0c;但是可以有多個非抽象方法的接口。 java8…