一、引言
The Torch-MLIR project provides core infrastructure for bridging the PyTorch ecosystem and the MLIR ecosystem. For example, Torch-MLIR enables PyTorch models to be lowered to a few different MLIR dialects. Torch-MLIR does not attempt to provide a production end-to-end flow for PyTorch programs by itself, but is a useful component for constructing one.
這是torch-mlir官方的介紹。總結下來就是 torch-mlir 提供了一個方案,用來連接 pytorch 生態和 mlir 生態。
二、torch.compile
下圖是pytorch 官方介紹 PyTorch 2.0 版本提供的pytorch 2.x 之后的編譯過程,即torch.compile。
圖的原文
可以發現torch.compile主要實現兩個過程:前端和后端
1、前端獲取計算圖輸出中間表示,供后端生成二進制可執行文件。主要通過 TorchDynamo+AOT Autograd 輸出中間表示 FX Graph。
其中TorchDynamo處理后,得到的是可以帶控制邏輯的高層計算圖,再經過AOT Autograd 處理就得到了一種通用的aten ir 中間表示,是不含控制邏輯底層計算圖,aten ir 可以等價于 pytorch 中定義的prim ops + aten ops。
2、后端翻譯生成二進制文件。pytorch 提供的一個工具是 TorchInductor,這里跳過,不是本文討論的重點。
而 torch-mlir 主要實現的部分是緊接著 aten ir,可以將 aten ir 繼續下降為 mlir,即接入 MLIR生態。
三、Torch-MLIR
torch-mlir 的實現也可以分為前端和后端。前后端的中間交互是 torch-mlir 自己定義的一個 dialect :Torch dialect。而 backend contract是Torch dialect的子集,既可以向上兼容 aten ir,又可以向下對接mlir 生態中的目標 dialect。
-
torch-mlir 的前端與pytorch 本身的接口有關,torch-mlir 封裝 pytorch 的 API 為 backend contract。 即也是將python program 下降為 aten ir。
-
torch-mlir 的后端是將 backend contract 下降為 mlir 生態中的目標 dialect。如 Linalg、TOSA、MHLO等。
下圖為torch-mlir的官方提供的架構圖
目前torch-mlir 項目有兩個主要的 API :torch_mlir.torchscript.compile 和 torch_mlir.fx.export_and_import。 -
第一條路徑是舊項目 pt1 代碼的一部分 (torch_mlir.torchscript.compile),允許用戶測試編譯器的 輸出到不同的 MLIR 方言
-
第二條路徑 (torch_mlir.fx.export_and_import)允許用戶導入任意 Python 可調用對象(nn.Module、函數或方法)的合并 torch.export.ExportedProgram 實例,并輸出到 torch dialect mlir 模塊。 該路徑與 PyTorch 的路線圖一致。
from torch_mlir import fx
from torch_mlir.compiler_utils import run_pipeline_with_repro_reportclass SimpleModel(torch.nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = torch.nn.Linear(10, 10)self.relu = torch.nn.ReLU()def forward(self, x):x = self.linear(x)x = self.relu(x)return x# front
aten_ir = fx.export_and_import(Basic(), torch.randn(3, 4))
print(aten_ir)# backend to linalg
run_pipeline_with_repro_report(aten_ir,("builtin.module(""func.func(torch-decompose-complex-ops),""torch-backend-to-linalg-on-tensors-backend-pipeline)"),"Lowering TorchFX IR -> Linalg IR",enable_ir_printing=False,
)# linalg dialect
print(aten_ir)