文章目錄
- 0 目的
- 1 TorchScript
- 1.1 語言特性的限定性
- 1.2 設計目的:模型表達的專注性
- 2pytorch JIT(Just-in-time compilation)
- 2.1pytorch JIT定義
- 2.1pytorch JIT整個過程:
- 1. 前端轉換層:生成靜態計算圖
- 2. 中間表示層(IR):靜態計算圖
- 3. 優化與編譯層
- 4. 執行層
- 3與pt文件的關系
- 3.1核心概念定義與層級關系
- 3.2pt文件構成與pth差異
0 目的
在部署時候需要靜態圖,靜態圖在pytorch中也稱為Script mode。為了script模式下的計算圖表達與優化,引入
兩個工具torchscript和pytorch JIT,一個工具將pytroch動態圖轉為靜態圖,另一個工具進行靜態圖優化。
1 TorchScript
TorchScript是Python靜態子集。TorchScript和完整Python語言之間最大的區別在于,TorchScript只支持表達神經網絡模型所需的一小部分靜態類型。TorchScript可以看成一種新的編程語言,設計的目的是為了脫離pytorch,python環境,作為python和其他語言(如c++)的一種中間橋接工具,方便部署。將pytorch寫的模型轉換成TorchScript語言后的代碼,稱為中間表示(IR),也稱為TorchScript IR,之前在計算圖中討論過,模型可以用計算圖(有向無環圖)表示,因此TorchScript IR也就是計算圖的中間表示。
1.1 語言特性的限定性
- 靜態類型約束:
TorchScript 僅支持 Python 中與神經網絡模型表達直接相關的有限類型(如Tensor
、Tuple
、int
、List
等),而舍棄了動態類型(如typing.Any
)和復雜控制流。
示例:變量必須聲明單一靜態類型,禁止運行時類型變更。 - 語法子集化:
僅保留if
/for
等基礎控制語句,且需符合靜態圖編譯要求(如循環次數需在編譯時可推斷)。其他 Python 特性(如動態類繼承、反射)被排除。
1.2 設計目的:模型表達的專注性
TorchScript 的語法設計聚焦于高效描述神經網絡計算圖,而非通用編程。這使其成為:
- 模型序列化載體:脫離 Python 環境后仍可完整表示計算邏輯。
- 編譯器友好格式:靜態類型便于優化器分析數據依賴與內存布局。
2pytorch JIT(Just-in-time compilation)
2.1pytorch JIT定義
JIT編譯器在模型運行時(而非訓練時)對代碼進行即時編譯與優化。在pytorch中JIT編譯器它不會將編譯的過程一口氣完成,而是先對代碼進行一些處理,存儲成某種序列化表示(比如計算圖);然后在實際的運行時環境中,通過 profiling 的方式,進行針對環境的優化并執行代碼。
pytorch JIT就是為了解決部署而誕生的工具。包括代碼的追蹤及解析、中間表示的生成、模型優化、序列化等各種功能,可以說是覆蓋了模型部署的方方面面。一方面使用TorchScript 作為python代碼的另一種表現形式,一方面對TorchScript IR進行優化。
其核心目標包括:
- 性能提升:通過算子融合、內存復用等優化手段加速推理(部分場景性能提升50%)。
- 部署解耦:脫離Python依賴,支持C++/移動端等非Python環境。
- 硬件適配:針對不同后端(CPU/GPU/TPU)生成優化機器碼。
2.1pytorch JIT整個過程:
TorchScript與PyTorch JIT的依賴關系
- TorchScript是JIT的前提:
動態圖模型必須首先轉換為TorchScript IR,才能被JIT編譯器優化。 - JIT賦予TorchScript執行能力:
TorchScript IR需通過JIT編譯為機器碼,否則僅為靜態數據結構。
1. 前端轉換層:生成靜態計算圖
轉換方式 | 原理 | 適用場景 |
---|---|---|
Tracing | 記錄示例輸入下的張量操作軌跡,生成IR(無法捕獲分支/循環) | 無控制流的簡單模型(如CNN) |
Scripting | 解析Python源碼,直接編譯為TorchScript(支持條件分支) | 含動態邏輯的模型(如RNN) |
-
1. 追蹤模式(Tracing)
-
原理:
輸入示例數據(如dummy_input
),記錄模型前向傳播的算子調用序列 → 生成線性計算圖dummy_input = torch.rand(1, 3, 224, 224) jit_model = torch.jit.trace(model, dummy_input) # 生成IR
-
局限性:
無法捕獲條件分支(如if x>0
)或循環(如for i in range(n)
),僅適合無控制流模型 -
2. 腳本模式(Scripting)
-
原理:
直接解析Python源碼 → 詞法分析(Lexer)→ 語法樹(AST)→ 語義分析 → 生成帶控制流的IRclass DynamicModel(nn.Module):def forward(self, x):if x.sum() > 0: return x * 2 # 分支邏輯可被保留else: return x / 2 jit_model = torch.jit.script(DynamicModel()) # 直接編譯源碼
2. 中間表示層(IR):靜態計算圖
IR數據結構
基于有向無環圖(DAG,這個詞在計算圖中出現過):
- Graph:頂級容器,表示一個函數(如
forward()
) - Block:基本塊(Basic Block),包含有序的Node序列
- Node:算子節點(如
aten::conv2d
),含輸入/輸出Value - Value:數據流邊,具有靜態類型(如
Tensor
/int
) - 屬性(Attributes) :存儲常量(如權重張量)
- 數據結構:,包含
Graph
(函數)、Node
(算子)、Value
(數據流)。 - 關鍵特性:
- 類型靜態化:所有變量需明確定義類型(如
Tensor
/int
),舍棄Python動態類型。 - 控制流顯式化:將
if
/for
轉換為圖節點(如prim::If
)。
- 類型靜態化:所有變量需明確定義類型(如
3. 優化與編譯層
- 圖優化(Graph Optimization):
- 算子融合:合并相鄰算子(如
conv2d + relu → conv2d_relu
),減少內核啟動開銷。 - 常量折疊:預計算靜態表達式(如
a=2; b=3; c=a*b
→c=6
)。
- 算子融合:合并相鄰算子(如
- 硬件后端適配:
- NVFuser:默認GPU優化器,針對NVIDIA顯卡生成高效CUDA核。
- CPU優化:利用OpenMP加速并行計算。
4. 執行層
- 輕量級解釋器:執行優化后的IR,無全局鎖(GIL),支持多線程并發。
- 運行時剖析(Profiling) :動態收集執行數據,反饋至編譯器迭代優化(如熱點代碼重編譯)。
3與pt文件的關系
將jit編譯優化后的模型進行保存,下一步就可以在C++上進行部署了。
序列化(Serialization) 是將程序中的對象(如模型參數、計算圖、張量等)轉換為可存儲或傳輸的標準化格式(如字節流、文件)的過程,而 反序列化(Deserialization) 則是將存儲的格式還原為內存中的對象。因此保存模型的save函數執行的就是序列化。
torch.jit.save('model.pt')
3.1核心概念定義與層級關系
PyTorch JIT保存的.pt
文件、PyTorch JIT編譯器與TorchScript三者構成模型部署的核心技術棧,其關系可通過以下分層架構與技術流程詳解:
組件 | 本質 | 角色定位 |
---|---|---|
TorchScript | Python的靜態類型子集(IR中間表示) | 模型表達層:定義可編譯的模型結構 |
PyTorch JIT | 運行時編譯器(Just-In-Time Compiler) | 優化執行層:將IR編譯為高效機器碼 |
.pt文件 | TorchScript模塊的序列化格式(ZIP歸檔) | 持久化層:存儲模型結構與參數 |
三者關系可概括為:
**TorchScript提供標準化模型表示 → PyTorch JIT進行運行時優化編譯 → .pt文件實現跨平臺持久化
3.2pt文件構成與pth差異
-
文件結構(ZIP歸檔格式):
model.pt ├── code/ # 優化后的TorchScript IR(計算圖) ├── data.pkl # 模型權重(張量數據) ├── constants.pkl # 嵌入的常量(如超參數) └── version # 格式版本號
-
序列化方法:
# 保存到磁盤文件 torch.jit.save(traced_model, "model.pt") # 或 traced_model.save("model.pt")# 保存到內存緩沖區(適用于網絡傳輸) buffer = io.BytesIO() torch.jit.save(traced_model, buffer)
pt文件 vs 普通PyTorch模型文件
特性 | JIT生成的.pt文件 | torch.save()保存的.pth文件 |
---|---|---|
內容 | 完整計算圖 + 參數 + 優化后的IR | 僅參數(state_dict)或Python類引用 |
可移植性 | 脫離Python環境(支持C++/移動端) | 依賴原始Python模型類定義 |
執行引擎 | JIT編譯器優化后的本地代碼 | Python解釋器執行 |
反編譯風險 | 代碼以IR存儲,難以還原原始Python邏輯 | 可直接查看模型類代碼 |
PyTorch JIT保存的.pt
文件、PyTorch JIT編譯器與TorchScript,三者協同構成PyTorch生產部署的核心基礎設施,覆蓋從研發到落地的完整生命周期。
參考
[1](TorchScript — PyTorch 2.7 documentation)
[3](PyTorch JIT and TorchScript. A path to production for PyTorch models | by Abhishek Sharma | TDS Archive | Medium)
[4]((8 封私信 / 5 條消息) TorchScript 解讀(一):初識 TorchScript - 知乎)
[5]((8 封私信 / 5 條消息) PyTorch系列「一」PyTorch JIT —— trace/ script的代碼組織和優化方法 - 知乎)
[6](PyTorch JIT | Chenglu’s Log)
[7](PyTorch Architecture | harleyszhang/llm_note | DeepWiki)
[8](TorchScript for Deployment — PyTorch Tutorials 2.7.0+cu126 documentation)
[9](Loading a TorchScript Model in C++ — PyTorch Tutorials 2.7.0+cu126 documentation)
[10](Introduction to TorchScript — PyTorch Tutorials 2.7.0+cu126 documentation)
[11]((8 封私信 / 5 條消息) 什么是torch.jit - 知乎)
[12]((8 封私信 / 5 條消息) 一文帶你使用即時編譯(JIT)提高 PyTorch 模型推理性能! - 知乎)
[13](TorchScript 解讀(二):Torch jit tracer 實現解析 - OpenMMLab的文章 - 知乎
https://zhuanlan.zhihu.com/p/489090393)
[14](Pytorch代碼部署:總結使用JIT將PyTorch模型轉換為TorchScript格式踩過的那些坑 - Ta沒有名字的文章 - 知乎
https://zhuanlan.zhihu.com/p/662228796)
[15](TorchScript JIT & IR - 靈丹的文章 - 知乎
https://zhuanlan.zhihu.com/p/543952666)
[16](TorchScript的簡介 - PyTorch官方教程中文版)
《深度學習編譯器設計第五章:中間表示》
《PRINCIPLED OPTIMIZATION OF DYNAMIC NEURAL NETWORKS》. JARED ROESCH