PyTorch JIT與TorchScript
轉自:https://zhuanlan.zhihu.com/p/370455320
如果搜索 PyTorch JIT,找到的將會是「TorchScript」的文檔,那么什么是 JIT 呢?JIT 和 TorchScript 又有什么聯系?
文章只會關注概念的部分,如果關注細節或實現部分,文章最后有一個完整的 Demo 可供參考。
什么是 JIT?
首先要知道 JIT 是一種概念,全稱是 Just In Time Compilation,中文譯為「即時編譯」,是一種程序優化的方法,一種常見的使用場景是「正則表達式」。例如,在 Python 中使用正則表達式:
prog = re.compile(pattern)
result = prog.match(string)
或
result = re.match(pattern, string)
上面兩個例子是直接從 Python 官方文檔中摘出來的 ,并且從文檔中可知,兩種寫法從結果上來說是「等價」的。但注意第一種寫法種,會先對正則表達式進行 compile,然后再進行使用。如果繼續閱讀 Python 的文檔,可以找到下面這段話:
using re.compile() and saving the resulting regular expression object for reuse is more efficient when the expression will be used several times in a single program.
也就是說,如果多次使用到某一個正則表達式,則建議先對其進行 compile,然后再通過 compile 之后得到的對象來做正則匹配。而這個 compile 的過程,就可以理解為 JIT(即時編譯)。
在深度學習中 JIT 的思想更是隨處可見,最明顯的例子就是 Keras 框架的 model.compile,TensorFlow 中的 Graph 也是一種 JIT,雖然他沒有顯示調用編譯方法。
那 PyTorch 呢?PyTorch 從面世以來一直以「易用性」著稱,最貼合原生 Python 的開發方式,這得益于 PyTorch 的「動態圖」結構。我們可以在 PyTorch 的模型前向中加任何 Python 的流程控制語句,甚至是下斷點單步跟進都不會有任何問題,但是如果是 TensorFlow,則需要使用 tf.cond 等 TensorFlow 自己開發的流程控制,誰更簡單一目了然。那么為什么 PyTorch 還需要引入 JIT 呢?
TorchScript
動態圖模型通過犧牲一些高級特性來換取易用性,那到底 JIT 有哪些特性,在什么情況下不得不用到 JIT 呢?下面主要通過介紹 TorchScript(PyTorch 的 JIT 實現)來分析 JIT 到底帶來了哪些好處。
- 模型部署
PyTorch 的 1.0 版本發布的最核心的兩個新特性就是 JIT 和 C++ API,這兩個特性一起發布不是沒有道理的,JIT 是 Python 和 C++ 的橋梁,我們可以使用 Python 訓練模型,然后通過 JIT 將模型轉為語言無關的模塊,從而讓 C++ 可以非常方便得調用,從此「使用 Python 訓練模型,使用 C++ 將模型部署到生產環境」對 PyTorch 來說成為了一件很容易的事。而因為使用了 C++,我們現在幾乎可以把 PyTorch 模型部署到任意平臺和設備上:樹莓派、iOS、Android 等等…
- 性能提升
既然是為部署生產所提供的特性,那免不了在性能上面做了極大的優化,如果推斷的場景對性能要求高,則可以考慮將模型(torch.nn.Module)轉換為 TorchScript Module,再進行推斷。
- 模型可視化
TensorFlow 或 Keras 對模型可視化工具(TensorBoard等)非常友好,因為本身就是靜態圖的編程模型,在模型定義好后整個模型的結構和正向邏輯就已經清楚了;但 PyTorch 本身是不支持的,所以 PyTorch 模型在可視化上一直表現得不好,但 JIT 改善了這一情況。現在可以使用 JIT 的 trace 功能來得到 PyTorch 模型針對某一輸入的正向邏輯,通過正向邏輯可以得到模型大致的結構,但如果在 forward
方法中有很多條件控制語句,這依然不是一個好的方法,所以 PyTorch JIT 還提供了 Scripting 的方式,這兩種方式在下文中將詳細介紹。
TorchScript Module 的兩種生成方式
1. 編碼(Scripting)
可以直接使用 TorchScript Language 來定義一個 PyTorch JIT Module,然后用 torch.jit.script 來將他轉換成 TorchScript Module 并保存成文件。而 TorchScript Language 本身也是 Python 代碼,所以可以直接寫在 Python 文件中。
使用 TorchScript Language 就如同使用 TensorFlow 一樣,需要前定義好完整的圖。對于 TensorFlow 我們知道不能直接使用 Python 中的 if 等語句來做條件控制,而是需要用 tf.cond,但對于 TorchScript 我們依然能夠直接使用 if 和 for 等條件控制語句,所以即使是在靜態圖上,PyTorch 依然秉承了「易用」的特性。TorchScript Language 是靜態類型的 Python 子集,靜態類型也是用了 Python 3 的 typing 模塊來實現,所以寫 TorchScript Language 的體驗也跟 Python 一模一樣,只是某些 Python 特性無法使用(因為是子集),可以通過 TorchScript Language Reference 來查看和原生 Python 的異同。
理論上,使用 Scripting 的方式定義的 TorchScript Module 對模型可視化工具非常友好,因為已經提前定義了整個圖結構。
2. 追蹤(Tracing)
使用 TorchScript Module 的更簡單的辦法是使用 Tracing,Tracing 可以直接將 PyTorch 模型(torch.nn.Module)轉換成 TorchScript Module。「追蹤」顧名思義,就是需要提供一個「輸入」來讓模型 forward 一遍,以通過該輸入的流轉路徑,獲得圖的結構。這種方式對于 forward 邏輯簡單的模型來說非常實用,但如果 forward 里面本身夾雜了很多流程控制語句,則可能會有問題,因為同一個輸入不可能遍歷到所有的邏輯分枝。
此外,還可以混合使用上面兩種方式。
一個完整的例子
我簡單寫了一個簡單的 MNIST demo,從使用 Python 訓練到用 JIT 將 Python 模型轉換為 TorchScript Module,然后用 C++ 加載 TorchScript Module 做推斷的完整的過程:
https://github.com/louis-she/torchscript-mnist