文章目錄
- torch.compiler
- 延伸閱讀
- torch.fft
- 快速傅里葉變換
- 輔助函數
- torch.func
- 什么是可組合的函數變換?
- 為什么需要可組合的函數變換?
- 延伸閱讀
- torch.futures
- torch.fx
- 概述
- 編寫轉換函數
- 圖結構快速入門
- 圖操作
- 直接操作計算圖
- 使用 replace_pattern() 進行子圖重寫
- 圖操作示例
- 代理/回溯機制
- 解釋器模式
- 解釋器模式示例
- 調試
- 簡介
- 變換編寫中的常見陷阱
- 檢查模塊的正確性
- 調試生成的代碼
- 使用 `pdb`
- 打印生成的代碼
- 使用 `GraphModule` 中的 `to_folder` 函數
- 調試轉換過程
- 可用的調試器
- 符號追蹤的局限性
- 動態控制流
- 靜態控制流
- 非`torch`函數
- 使用 `Tracer` 類自定義追蹤功能
- 葉子模塊
- 雜項說明
- API 參考
- torch.fx.experimental
- torch.fx.experimental.symbolic_shapes
- torch.fx.experimental.proxy_tensor
torch.compiler
torch.compiler
是一個命名空間,通過它向用戶開放了一些內部編譯器方法。該命名空間中的主要功能和特性是 torch.compile
。
torch.compile
是 PyTorch 2.x 引入的一個函數,旨在解決 PyTorch 中精確圖捕獲的問題,最終幫助軟件工程師加速運行他們的 PyTorch 程序。torch.compile
使用 Python 編寫,標志著 PyTorch 從 C++ 向 Python 的過渡。
torch.compile
利用了以下底層技術:
- TorchDynamo (torch._dynamo) 是一個內部 API,它使用 CPython 的 Frame Evaluation API 功能來安全捕獲 PyTorch 計算圖。通過
torch.compiler
命名空間向 PyTorch 用戶開放可用方法。 - TorchInductor 是
torch.compile
默認的深度學習編譯器,為多種加速器和后端生成快速代碼。需要通過后端編譯器才能實現torch.compile
的加速效果。對于 NVIDIA、AMD 和 Intel GPU,它使用 OpenAI Triton 作為關鍵構建塊。 - AOT Autograd 不僅能捕獲用戶級代碼,還能捕獲反向傳播,實現"提前"捕獲反向傳遞。這使得 TorchInductor 能夠同時加速前向和反向傳遞。
注意:在本文檔中,術語 torch.compile
、TorchDynamo 和 torch.compiler
有時會互換使用。
如上所述,要通過 TorchDynamo 運行更快的工作流,torch.compile
需要一個后端將捕獲的計算圖轉換為快速機器碼。不同的后端會帶來不同的優化效果。默認后端是 TorchInductor(也稱為 inductor)。TorchDynamo 還支持由合作伙伴開發的一系列后端,可以通過運行 torch.compiler.list_backends()
查看,每個后端都有其可選依賴項。
一些最常用的后端包括:
訓練和推理后端
后端 | 描述 |
---|---|
torch.compile(m, backend="inductor") | 使用 TorchInductor 后端。了解更多 |
torch.compile(m, backend="cudagraphs") | 使用 AOT Autograd 的 CUDA 圖。了解更多 |
torch.compile(m, backend="ipex") | 在 CPU 上使用 IPEX。了解更多 |
torch.compile(m, backend="onnxrt") | 使用 ONNX Runtime 在 CPU/GPU 上進行訓練。了解更多 |
僅推理后端
后端 | 描述 |
---|---|
torch.compile(m, backend="tensorrt") | 使用 Torch-TensorRT 進行推理優化。需要在調用腳本中 import torch_tensorrt 來注冊后端。了解更多 |
torch.compile(m, backend="ipex") | 在 CPU 上使用 IPEX 進行推理。了解更多 |
torch.compile(m, backend="tvm") | 使用 Apache TVM 進行推理優化。了解更多 |
torch.compile(m, backend="openvino") | 使用 OpenVINO 進行推理優化。了解更多 |
延伸閱讀
PyTorch 用戶入門指南
- 快速入門
- torch.compiler API 參考
- torch.compiler.config 配置
- TorchDynamo 細粒度追蹤 API
- AOTInductor: Torch.Export 模型的預編譯方案
- TorchInductor GPU 性能分析
- torch.compile 性能剖析指南
- 常見問題解答
- torch.compile 故障排查
- PyTorch 2.0 性能看板
PyTorch 開發者深度解析
- Dynamo 架構概覽
- Dynamo 技術深潛
- 動態形狀支持
- PyTorch 2.0 NNModule 支持
- 后端開發最佳實踐
- CUDA 圖樹優化
- 偽張量機制
PyTorch 后端供應商指南
- 自定義后端開發
- ATen IR 圖轉換開發
- 中間表示層詳解
torch.fft
離散傅里葉變換及相關函數。
快速傅里葉變換
fft | 計算input 的一維離散傅里葉變換 |
---|---|
ifft | 計算input 的一維離散傅里葉逆變換 |
fft2 | 計算input 的二維離散傅里葉變換 |
ifft2 | 計算input 的二維離散傅里葉逆變換 |
fftn | 計算input 的N維離散傅里葉變換 |
ifftn | 計算input 的N維離散傅里葉逆變換 |
rfft | 計算實數input 的一維傅里葉變換 |
irfft | 計算rfft() 的逆變換 |
rfft2 | 計算實數input 的二維離散傅里葉變換 |
irfft2 | 計算rfft2() 的逆變換 |
rfftn | 計算實數input 的N維離散傅里葉變換 |
irfftn | 計算rfftn() 的逆變換 |
hfft | 計算Hermitian對稱input 信號的一維離散傅里葉變換 |
ihfft | 計算hfft() 的逆變換 |
hfft2 | 計算Hermitian對稱input 信號的二維離散傅里葉變換 |
ihfft2 | 計算實數input 的二維離散傅里葉逆變換 |
hfftn | 計算Hermitian對稱input 信號的N維離散傅里葉變換 |
ihfftn | 計算實數input 的N維離散傅里葉逆變換 |
輔助函數
fftfreq | 計算大小為 n 的信號的離散傅里葉變換采樣頻率。 |
---|---|
rfftfreq | 計算大小為 n 的信號在使用 rfft() 時的采樣頻率。 |
fftshift | 對由 fftn() 提供的 n 維 FFT 數據進行重新排序,使負頻率項優先。 |
ifftshift | fftshift() 的逆操作。 |
torch.func
torch.func(前身為"functorch")是為PyTorch提供的JAX風格可組合函數變換工具。
注意:該庫目前處于測試階段。
這意味著這些功能基本可用(除非另有說明),且我們(PyTorch團隊)將持續推進該庫的發展。但API可能會根據用戶反饋進行調整,且尚未完全覆蓋所有PyTorch操作。
如果您對API有改進建議,或希望支持特定使用場景,請提交GitHub issue或直接聯系我們。我們非常期待了解您如何使用這個庫。
什么是可組合的函數變換?
- 函數變換是一種高階函數,它接受一個數值函數作為輸入,并返回一個新函數來計算不同的量。
torch.func
提供了自動微分變換(例如grad(f)
返回計算f
梯度的函數)、向量化/批處理變換(例如vmap(f)
返回對輸入批次執行f
的函數)等多種變換。- 這些函數變換可以任意組合使用。例如,組合
vmap(grad(f))
可以計算單樣本梯度(per-sample-gradients),這是當前標準 PyTorch 無法高效計算的量。
為什么需要可組合的函數變換?
目前在 PyTorch 中實現以下用例較為棘手:
- 計算逐樣本梯度(或其他逐樣本量)
- 在單臺機器上運行模型集成
- 在 MAML 內循環中高效批處理任務
- 高效計算雅可比矩陣和海森矩陣
- 高效計算批量雅可比矩陣和海森矩陣
通過組合使用 vmap()
、grad()
和 vjp()
變換,我們無需為每個用例單獨設計子系統即可實現上述功能。這種可組合函數變換的理念源自 JAX 框架。
延伸閱讀
- torch.func 快速指南
- 什么是 torch.func?
- 為什么需要可組合函數變換?
- 有哪些變換方法?
- torch.func API 參考
- 函數變換
- torch.nn.Module 工具集
- 調試工具
- 使用限制
- 通用限制
- torch.autograd API
- vmap 限制
- 隨機性控制
- 從 functorch 遷移到 torch.func
- 函數變換
- 神經網絡模塊工具
- functorch.compile
torch.futures
該包提供了一種 Future
類型,用于封裝異步執行過程,并提供一組實用函數來簡化對 Future
對象的操作。目前,Future
類型主要被 分布式RPC框架 使用。
class torch.futures.Future(*, devices=None)
Wrapper around a torch._C.Future
which encapsulates an asynchronous
execution of a callable, e.g. rpc_async()
. It also exposes a set of APIs to add callback functions and set results.
Warning: GPU support is a beta feature, subject to changes.
add_done_callback(callback)
將給定的回調函數附加到此Future
上,該回調函數將在Future
完成時運行。可以向同一個Future
添加多個回調,但無法保證它們的執行順序。回調函數必須接受一個參數,即對此Future
的引用。回調函數可以使用value()
方法獲取值。請注意,如果此Future
已經完成,給定的回調將立即內聯執行。
我們建議使用then()
方法,因為它提供了一種在回調完成后進行同步的方式。如果回調不返回任何內容,add_done_callback
可能更高效。但then()
和add_done_callback
在底層使用相同的回調注冊API。
對于GPU張量,此方法的行為與then()
相同。
參數
callback (
Future)
– 一個可調用對象,接受一個參數,即對此Future
的引用。
注意:請注意,如果回調函數拋出異常,無論是由于原始future以異常完成并調用fut.wait()
,還是由于回調中的其他代碼,都必須仔細處理錯誤。例如,如果此回調隨后完成了其他future,這些future不會被標記為以錯誤完成,用戶需要獨立處理這些future的完成/等待。
示例:
>>> def callback(fut):
... print("This will run after the future has finished.")
... print(fut.wait())
>>> fut = torch.futures.Future()
>>> fut.add_done_callback(callback)
>>> fut.set_result(5)
This will run after the future has finished.
5
done()
如果該Future
已完成則返回True
。當Future
包含結果或異常時即視為完成。
如果值包含位于GPU上的張量,即使填充這些張量的異步內核尚未在設備上完成運行,Future.done()
仍會返回True
,因為在此階段結果已可被使用(前提是執行適當的同步操作,參見wait()
)。
返回類型:bool
set_exception(result)
為這個 Future
設置一個異常,這將標記該 Future
以錯誤狀態完成,并觸發所有已附加的回調。請注意,當對此 Future
調用 wait()/value() 時,此處設置的異常
將被內聯拋出。
參數
result ([BaseException](https://docs.python.org/3/library/exceptions.html#BaseException "(in Python v3.13)"))
– 該Future
的異常對象。
示例
>>> fut = torch.futures.Future()
>>> fut.set_exception(ValueError("foo"))
>>> fut.wait()
Traceback (most recent call last):
...
ValueError: foo
set_result(result)
為這個Future
設置結果,這將標記該Future
為已完成狀態并觸發所有關聯的回調。需要注意的是,一個Future
不能被標記為已完成兩次。
如果結果包含位于GPU上的張量,即使填充這些張量的異步內核尚未在設備上完成運行,只要調用此方法時這些內核所入隊的流被設置為當前流,仍可調用此方法。簡而言之,在啟動這些內核后立即調用此方法是安全的,無需額外同步,前提是期間不切換流。此方法會在所有相關當前流上記錄事件,并利用它們確保此Future
的所有消費者都能得到正確調度。
參數
result ( object )
- 該Future
的結果對象。
示例:
>>> import threading
>>> import time
>>> def slow_set_future(fut, value):
... time.sleep(0.5)
... fut.set_result(value)
>>> fut = torch.futures.Future()
>>> t = threading.Thread(
... target=slow_set_future,
... args=(fut, torch.ones(2) * 3)
... )
>>> t.start()
>>> print(fut.wait())
tensor([3., 3.])
>>> t.join()
then(callback)
將給定的回調函數附加到此Future
上,該回調函數將在Future
完成時運行。可以向同一個Future
添加多個回調,但無法保證它們的執行順序(如需確保特定順序,請考慮鏈式調用:fut.then(cb1).then(cb2)
)。回調函數必須接受一個參數,即對此Future
的引用。回調函數可通過value()
方法獲取值。請注意,如果此Future
已完成,給定的回調將立即內聯執行。
如果Future
的值包含位于GPU上的張量,回調可能在填充這些張量的異步內核尚未在設備上完成執行時就被調用。不過,回調將通過設置為當前的一些專用流(從全局池中獲取)被調用,這些流將與那些內核同步。因此,回調對這些張量執行的任何操作都將在內核完成后調度到設備上。換句話說,只要回調不切換流,它就可以安全地操作結果而無需額外同步。這與wait()
的非阻塞行為類似。
類似地,如果回調返回的值包含位于GPU上的張量,即使生成這些張量的內核仍在設備上運行,回調也可以這樣做,前提是回調在執行期間沒有切換流。如果想要切換流,必須注意與原始流重新同步,即回調被調用時當前的流。
參數
callback (
Callable)
– 一個以該Future
為唯一參數的可調用對象。
返回
一個新的Future
對象,它持有callback
的返回值,并將在給定callback
完成時標記為已完成。
返回類型
Future[S]
注意:請注意,如果回調函數拋出異常,無論是通過原始future以異常完成并調用fut.wait()
,還是通過回調中的其他代碼,then
返回的future將適當地標記為遇到錯誤。但是,如果此回調隨后完成其他future,這些future不會標記為以錯誤完成,用戶需負責獨立處理這些future的完成/等待。
示例:
>>> def callback(fut):
... print(f"RPC return value is {fut.wait()}.")
>>> fut = torch.futures.Future()
>>> # The inserted callback will print the return value when
>>> # receiving the response from "worker1"
>>> cb_fut = fut.then(callback)
>>> chain_cb_fut = cb_fut.then(
... lambda x : print(f"Chained cb done. {x.wait()}")
... )
>>> fut.set_result(5)
RPC return value is 5、Chained cb done. None
value()
獲取已完成的Future對象的值。
此方法僅應在調用wait()
完成后,或在傳遞給then()
的回調函數內部使用。其他情況下,該Future
可能尚未持有值,調用value()
可能會失敗。
如果值包含位于GPU上的張量,此方法將不會執行任何額外的同步操作。此類同步應事先通過調用wait()
單獨完成(回調函數內部除外,因為then()
已自動處理此情況)。
返回值
該Future
持有的值。如果創建該值的函數(回調或RPC)拋出錯誤,此value()
方法同樣會拋出錯誤。
返回類型:T
wait()
等待直到該 Future
的值準備就緒。
如果值包含位于 GPU 上的張量,則會與設備上異步填充這些張量的內核執行額外的同步操作。此類同步是非阻塞的,這意味著 wait()
會在當前流中插入必要的指令,以確保后續在這些流上排隊的操作能正確安排在異步內核之后執行。但一旦完成指令插入,即使這些內核仍在運行,wait()
也會立即返回。只要不切換流,在訪問和使用這些值時無需進一步同步。
返回值:此 Future
持有的值。如果創建該值的函數(回調或 RPC)拋出錯誤,此 wait
方法同樣會拋出錯誤。
返回類型:T
torch.futures.collect_all(futures)
將提供的 Future
對象收集到一個統一的組合 Future
中,該組合 Future 會在所有子 Future 完成時完成。
參數
futures (list)
– 一個包含Future
對象的列表。
返回
返回一個 Future
對象,該對象關聯到傳入的 Future 列表。
返回類型
Future[list [torch.jit.Future]]
示例
>>> fut0 = torch.futures.Future()
>>> fut1 = torch.futures.Future()
>>> fut = torch.futures.collect_all([fut0, fut1])
>>> fut0.set_result(0)
>>> fut1.set_result(1)
>>> fut_list = fut.wait()
>>> print(f"fut0 result = {fut_list[0].wait()}")
fut0 result = 0
>>> print(f"fut1 result = {fut_list[1].wait()}")
fut1 result = 1
torch.futures.wait_all(futures)
等待所有提供的 futures 完成,并返回已完成值的列表。如果任一 future 遇到錯誤,該方法將提前退出并報告錯誤,而不會等待其他 futures 完成。
參數
futures (list)
– 一個Future
對象列表。
返回值:已完成 Future
結果的列表。如果對任何 Future
調用 wait
時拋出錯誤,該方法也會拋出錯誤。
返回類型:list
torch.fx
概述
FX 是一個供開發者使用的工具包,用于轉換 nn.Module
實例。FX 包含三個核心組件:符號追蹤器、中間表示和 Python 代碼生成。以下是這些組件的實際應用演示:
import torch# Simple module for demonstration
class MyModule(torch.nn.Module):def __init__(self) -None:super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return self.linear(x + self.param).clamp(min=0.0, max=1.0)module = MyModule()from torch.fx import symbolic_trace# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():%x : [num_users=1] = placeholder[target=x]%param : [num_users=1] = get_attr[target=param]%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})return clamp
"""# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):param = self.paramadd = x + param; x = param = Nonelinear = self.linear(add); add = Noneclamp = linear.clamp(min = 0.0, max = 1.0); linear = Nonereturn clamp
"""
符號追蹤器(symbolic tracer)對Python代碼執行"符號執行"。它通過代碼傳遞稱為Proxy的虛擬值,并記錄對這些Proxy的操作。有關符號追蹤的更多信息,請參閱symbolic_trace()
和Tracer
文檔。
中間表示(intermediate representation)是符號追蹤過程中記錄操作的容器。它由一組節點組成,這些節點表示函數輸入、調用點(指向函數、方法或torch.nn.Module
實例)以及返回值。有關IR的更多信息,請參閱Graph
文檔。IR是應用轉換的基礎格式。
Python代碼生成功能使FX成為Python到Python(或Module到Module)的轉換工具包。對于每個Graph IR,我們都可以生成符合Graph語義的有效Python代碼。這個功能被封裝在GraphModule
中,它是一個torch.nn.Module
實例,包含一個Graph
以及從Graph生成的forward
方法。
這些組件(符號追蹤→中間表示→轉換→Python代碼生成)共同構成了FX的Python到Python轉換流程。此外,這些組件也可以單獨使用。例如,符號追蹤可以單獨用于捕獲代碼形式進行分析(而非轉換)目的。代碼生成可以用于通過編程方式生成模型,例如從配置文件生成。FX有許多用途!
在示例庫中可以找到幾個轉換示例。
編寫轉換函數
什么是FX轉換?本質上,它是一個形如下列的函數。
import torch
import torch.fxdef transform(m: nn.Module, tracer_class : type = torch.fx.Tracer) -torch.nn.Module:# Step 1: Acquire a Graph representing the code in `m`# NOTE: torch.fx.symbolic_trace is a wrapper around a call to # fx.Tracer.trace and constructing a GraphModule. We'll# split that out in our transform to allow the caller to # customize tracing behavior.graph : torch.fx.Graph = tracer_class().trace(m)# Step 2: Modify this Graph or create a new onegraph = ...# Step 3: Construct a Module to returnreturn torch.fx.GraphModule(m, graph)
您的轉換器將接收一個 torch.nn.Module
,從中獲取 Graph
,進行一些修改后返回一個新的 torch.nn.Module
。您應該將 FX 轉換器返回的 torch.nn.Module
視為與常規 torch.nn.Module
完全相同——可以將其傳遞給另一個 FX 轉換器、傳遞給 TorchScript 或直接運行它。確保 FX 轉換器的輸入和輸出均為 torch.nn.Module
將有助于實現組合性。
注意:也可以直接修改現有的 GraphModule
而不創建新實例,例如:
import torch
import torch.fxdef transform(m : nn.Module) -nn.Module:gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)# Modify gm.graph# <...># Recompile the forward() method of `gm` from its Graphgm.recompile()return gm
請注意,你必須調用 GraphModule.recompile()
方法,使生成的 forward()
方法與修改后的 Graph
保持同步。
假設你已經傳入了一個經過追蹤轉換為 Graph
的 torch.nn.Module
,現在主要有兩種方法來構建新的 Graph
。
圖結構快速入門
關于圖的語義完整說明可以參考 Graph
文檔,這里我們主要介紹基礎概念。Graph
是一種數據結構,用于表示 GraphModule
上的方法。其核心需要描述以下信息:
- 方法的輸入參數是什么?
- 方法內部運行了哪些操作?
- 方法的輸出(即返回值)是什么?
這三個概念都通過 Node
實例來表示。下面通過一個簡單示例來說明:
import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self):super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)m = MyModule()
gm = torch.fx.symbolic_trace(m)gm.graph.print_tabular()
這里我們定義一個演示用的模塊 MyModule
,實例化后進行符號追蹤,然后調用 Graph.print_tabular()
方法打印該 Graph
的節點表格:
操作碼 | 名稱 | 目標 | 參數 | 關鍵字參數 |
---|---|---|---|---|
placeholder | x | x | () | {} |
get_attr | linear_weight | linear.weight | () | {} |
call_function | add_1 | <built-in function add | (x, linear_weight) | {} |
call_module | linear_1 | linear | (add_1,) | {} |
call_method | relu_1 | relu | (linear_1,) | {} |
call_function | sum_1 | <built-in method sum … | (relu_1,) | {‘dim’: -1} |
call_function | topk_1 | <built-in method topk … | (sum_1, 3) | {} |
output | output | output | (topk_1,) | {} |
通過這些信息,我們可以回答之前提出的問題:
- 方法的輸入是什么?
在FX中,方法輸入通過特殊的placeholder
節點指定。本例中有一個目標為x
的placeholder
節點,表示存在一個名為x的(非self)參數。 - 方法內部有哪些操作?
get_attr
、call_function
、call_module
和call_method
節點表示方法中的操作。這些節點的完整語義說明可參考Node
文檔。 - 方法的返回值是什么?
在Graph
中,返回值由特殊的output
節點指定。
現在我們已經了解FX中代碼表示的基本原理,接下來可以探索如何編輯 Graph
。
圖操作
直接操作計算圖
構建新Graph
的一種方法是直接操作原有計算圖。為此,我們可以簡單地獲取通過符號追蹤得到的Graph
并進行修改。例如,假設我們需要將所有torch.add()
調用替換為torch.mul()
調用。
import torch
import torch.fx# Sample module
class M(torch.nn.Module):def forward(self, x, y):return torch.add(x, y)def transform(m: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:graph : fx.Graph = tracer_class().trace(m)# FX represents its Graph as an ordered list of # nodes, so we can iterate through them.for node in graph.nodes:# Checks if we're calling a function (i.e:# torch.add)if node.op == 'call_function':# The target attribute is the function# that call_function calls.if node.target == torch.add:node.target = torch.mulgraph.lint() # Does some checks to make sure the # Graph is well-formed.return fx.GraphModule(m, graph)
我們還可以進行更復雜的 Graph
重寫操作,例如刪除或追加節點。為了輔助這些轉換,FX 提供了一些用于操作計算圖的實用函數,這些函數可以在 Graph
文檔中找到。
下面展示了一個使用這些 API 追加 torch.relu()
調用的示例。
# Specifies the insertion point. Any nodes added to the # Graph within this scope will be inserted after `node` with traced.graph.inserting_after(node):# Insert a new `call_function` node calling `torch.relu`new_node = traced.graph.call_function(torch.relu, args=(node,))# We want all places that used the value of `node` to # now use that value after the `relu` call we've added.# We use the `replace_all_uses_with` API to do this.node.replace_all_uses_with(new_node)
對于僅包含替換操作的簡單轉換,您也可以使用子圖重寫器。
使用 replace_pattern() 進行子圖重寫
FX 在直接圖操作的基礎上提供了更高層次的自動化能力。replace_pattern()
API 本質上是一個用于編輯 Graph
的"查找/替換"工具。它允許你指定一個 pattern
(模式)和 replacement
(替換)函數,然后會追蹤這些函數,在圖中找到與 pattern
圖匹配的操作組實例,并用 replacement
圖的副本替換這些實例。這可以極大地自動化繁瑣的圖操作代碼,隨著轉換邏輯變得復雜,手動操作會變得難以維護。
圖操作示例
- 替換單個操作符
- 卷積/批量歸一化融合
- replace_pattern:基礎用法
- 量化
- 逆變換
代理/回溯機制
另一種操作 Graph
的方式是復用符號追蹤中使用的 Proxy
機制。例如,假設我們需要編寫一個將 PyTorch 函數分解為更小操作的轉換器:將每個 F.relu(x)
調用轉換為 (x > 0) * x
。傳統做法可能是通過圖重寫來插入比較和乘法操作,然后清理原始的 F.relu
。但借助 Proxy
對象,我們可以自動將操作記錄到 Graph
中來實現這一過程。
具體實現時,只需將需要插入的操作寫成常規 PyTorch 代碼,并用 Proxy
對象作為參數調用該代碼。這些 Proxy
對象會捕獲對其執行的操作,并將其追加到 Graph
中。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):return (x 0) * xdecomposition_rules = {}
decomposition_rules[F.relu] = relu_decompositiondef decompose(model: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:"""Decompose `model` into smaller constituent operations.Currently,this only supports decomposing ReLU into itsmathematical definition: (x 0) * x"""graph : fx.Graph = tracer_class().trace(model)new_graph = fx.Graph()env = {}tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)for node in graph.nodes:if node.op == 'call_function' and node.target in decomposition_rules:# By wrapping the arguments with proxies, # we can dispatch to the appropriate# decomposition rule and implicitly add it# to the Graph by symbolically tracing it.proxy_args = [fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]output_proxy = decomposition_rules[node.target](proxy_args)# Operations on `Proxy` always yield new `Proxy`s, and the # return value of our decomposition rule is no exception.# We need to extract the underlying `Node` from the `Proxy`# to use it in subsequent iterations of this transform.new_node = output_proxy.nodeenv[node.name] = new_nodeelse:# Default case: we don't have a decomposition rule for this # node, so just copy the node over into the new graph.new_node = new_graph.node_copy(node, lambda x: env[x.name])env[node.name] = new_nodereturn fx.GraphModule(model, new_graph)
除了避免顯式的圖操作外,使用Proxy
還允許您將重寫規則指定為原生Python代碼。對于需要大量重寫規則的轉換(如vmap或grad),這通常可以提高規則的可讀性和可維護性。
需要注意的是,在調用Proxy
時,我們還傳遞了一個指向底層變量圖的追蹤器。這樣做是為了防止當圖中的操作是n元操作時(例如add是二元運算符),調用Proxy
不會創建多個圖追蹤器實例,否則可能導致意外的運行時錯誤。特別是在底層操作不能安全地假設為一元操作時,我們推薦使用這種Proxy
方法。
一個使用Proxy
進行Graph
操作的實際示例可以在這里找到。
解釋器模式
在FX中,一個實用的代碼組織模式是遍歷Graph
中的所有Node
并執行它們。這種模式可用于多種場景,包括:
- 運行時分析流經計算圖的值
- 通過
Proxy
重新追蹤來實現代碼轉換
例如,假設我們想運行一個GraphModule
,并在運行時記錄節點上torch.Tensor
的形狀和數據類型屬性。實現代碼可能如下:
import torch
import torch.fx
from torch.fx.node import Nodefrom typing import Dictclass ShapeProp:"""Shape propagation. This class takes a `GraphModule`.Then, its `propagate` method executes the `GraphModule`node-by-node with the given arguments. As each operationexecutes, the ShapeProp class stores away the shape and element type for the output values of each operation on the `shape` and `dtype` attributes of the operation's`Node`."""def __init__(self, mod):self.mod = modself.graph = mod.graphself.modules = dict(self.mod.named_modules())def propagate(self, args):args_iter = iter(args)env : Dict[str, Node] = {}def load_arg(a):return torch.fx.graph.map_arg(a, lambda n: env[n.name])def fetch_attr(target : str):target_atoms = target.split('.')attr_itr = self.modfor i, atom in enumerate(target_atoms):if not hasattr(attr_itr, atom):raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")attr_itr = getattr(attr_itr, atom)return attr_itrfor node in self.graph.nodes:if node.op == 'placeholder':result = next(args_iter)elif node.op == 'get_attr':result = fetch_attr(node.target)elif node.op == 'call_function':result = node.target(load_arg(node.args), *load_arg(node.kwargs))elif node.op == 'call_method':self_obj, args = load_arg(node.args)kwargs = load_arg(node.kwargs)result = getattr(self_obj, node.target)(args, *kwargs)elif node.op == 'call_module':result = self.modules[node.target](load_arg(node.args), *load_arg(node.kwargs))# This is the only code specific to shape propagation.# you can delete this `if` branch and this becomes# a generic GraphModule interpreter.if isinstance(result, torch.Tensor):node.shape = result.shapenode.dtype = result.dtypeenv[node.name] = resultreturn load_arg(self.graph.result)
如你所見,為FX實現一個完整的解釋器并不復雜,但卻非常實用。為了簡化這一模式的使用,我們提供了Interpreter
類,它封裝了上述邏輯,允許通過方法重寫來覆蓋解釋器執行的某些方面。
除了執行操作外,我們還可以通過向解釋器傳遞Proxy
值來生成新的計算圖。
類似地,我們提供了Transformer
類來封裝這種模式。Transformer
的行為與Interpreter
類似,但不同于調用run
方法從模塊獲取具體輸出值,你需要調用Transformer.transform()
方法來返回一個新的GraphModule
,該模塊會應用你通過重寫方法設置的任何轉換規則。
解釋器模式示例
- 形狀傳播
- 性能分析器
調試
簡介
在編寫轉換代碼的過程中,我們的代碼往往不會一開始就完全正確。這時就需要進行調試。關鍵在于采用逆向思維:首先檢查調用生成模塊的結果,驗證其正確性;接著審查并調試生成的代碼;最后追溯導致生成代碼的轉換過程并進行調試。
如果您不熟悉調試工具,請參閱輔助章節可用調試工具。
變換編寫中的常見陷阱
set
迭代順序的不確定性。在Python中,set
數據類型是無序的。例如,使用set
來存儲Node
等對象集合可能導致意外的非確定性行為。比如當迭代一組Node
并將其插入Graph
時,由于set
數據類型是無序的,輸出程序中操作的順序將是非確定性的,且每次程序調用都可能變化。
推薦的替代方案是使用dict
數據類型。自Python 3.7起(以及cPython 3.6起),dict
保持了插入順序。通過將需要去重的值存儲在dict
的鍵中,可以等效地實現set
的功能。
檢查模塊的正確性
由于大多數深度學習模塊的輸出都是浮點型 torch.Tensor
實例,因此檢查兩個 torch.nn.Module
的結果是否相等并不像簡單的相等性檢查那樣直接。為了說明這一點,我們來看一個示例:
import torch
import torch.fx
import torchvision.models as modelsdef transform(m : torch.nn.Module) -torch.nn.Module:gm = torch.fx.symbolic_trace(m)# Imagine we're doing some transforms here# <...>gm.recompile()return gmresnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)input_image = torch.randn(5, 3, 224, 224)assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
在這里,我們嘗試使用==
相等運算符來檢查兩個深度學習模型的值是否相等。然而,這種做法存在兩個問題:首先,該運算符返回的是張量而非布爾值;其次,浮點數值的比較應考慮誤差范圍(或epsilon),以解決浮點運算不可交換性的問題(詳見此處)。
我們可以改用torch.allclose()
函數,它會基于相對和絕對容差閾值進行近似比較:
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
這是我們工具箱中的第一個工具,用于檢查轉換后的模塊與參考實現相比是否按預期運行。
調試生成的代碼
由于 FX 在 GraphModule
上生成 forward()
函數,使用傳統的調試技術(如 print
語句或 pdb
)會不太直觀。幸運的是,我們有多種方法可以用來調試生成的代碼。
使用 pdb
通過調用 pdb
可以進入正在運行的程序進行調試。雖然表示 Graph
的代碼不在任何源文件中,但當執行前向傳播時,我們仍然可以手動使用 pdb
進入該代碼進行調試。
import torch
import torch.fx
import torchvision.models as modelsdef my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:graph = tracer_class().trace(inp)# Transformation logic here# <...># Return new Modulereturn fx.GraphModule(inp, graph)my_module = models.resnet18()
my_module_transformed = my_pass(my_module)input_value = torch.randn(5, 3, 224, 224)# When this line is executed at runtime, we will be dropped into an # interactive `pdb` prompt. We can use the `step` or `s` command to # step into the execution of the next line
import pdb; pdb.set_trace()my_module_transformed(input_value)
打印生成的代碼
如果需要多次運行相同的代碼,使用pdb
逐步調試到目標代碼可能會有些繁瑣。這種情況下,一個簡單的方法是將生成的forward
傳遞代碼直接復制粘貼到你的代碼中,然后在那里進行檢查。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):x = self.xadd_1 = x + y; x = y = Nonereturn add_1
"""# Subclass the original Module
class SubclassM(M):def __init__(self):super().__init__()# Paste the generated `forward` function (the one we printed and # copied above) heredef forward(self, y):x = self.xadd_1 = x + y; x = y = Nonereturn add_1# Create an instance of the original, untraced Module. Then, create an # instance of the Module with the copied `forward` function. We can # now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 GraphModule
中的 to_folder
函數
GraphModule.to_folder()
是 GraphModule
中的一個方法,它允許你將生成的 FX 代碼導出到一個文件夾。雖然像打印生成的代碼中那樣直接復制前向傳播代碼通常已經足夠,但使用 to_folder
可以更方便地檢查模塊和參數。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
運行上述示例后,我們可以查看foo/module.py
中的代碼,并根據需要進行修改(例如添加print
語句或使用pdb
)來調試生成的代碼。
調試轉換過程
既然我們已經確認是轉換過程生成了錯誤代碼,現在就該調試轉換本身了。首先,我們會查閱文檔中的符號追蹤限制部分。在確認追蹤功能按預期工作后,我們的目標就轉變為找出GraphModule
轉換過程中出現的問題。編寫轉換部分可能有快速解決方案,如果沒有的話,我們還可以通過多種方式來檢查追蹤模塊:
# Sample Module
class M(torch.nn.Module):def forward(self, x, y):return x + y# Create an instance of `M`
m = M()# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a # GraphModule, so we aren't showing any sample transforms for the # sake of brevity.
traced = symbolic_trace(m)# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):add = x + y; x = y = Nonereturn add
"""# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():%x : [num_users=1] = placeholder[target=x]%y : [num_users=1] = placeholder[target=y]%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})return add
"""# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add (x, y) {}
output output output (add,) {}
"""
通過使用上述工具函數,我們可以對比應用轉換前后的追蹤模塊。有時,簡單的視覺對比就足以定位錯誤。如果問題仍不明確,下一步可以嘗試使用 pdb
這類調試器。
以上述示例為基礎,請看以下代碼:
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:# Get the Graph from our traced Moduleg = tracer_class().trace(module)"""Transformations on `g` go here"""return fx.GraphModule(module, g)# Transform the Graph
transformed = transform_graph(traced)# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
以上述示例為例,假設調用print(traced)
時發現轉換過程中存在錯誤。我們需要通過調試器定位問題根源。啟動pdb
調試會話后,可以在transform_graph(traced)
處設置斷點,然后按s
鍵"步入"該函數調用,實時觀察轉換過程。
另一個有效方法是修改print_tabular
方法,使其輸出圖中節點的不同屬性(例如查看節點的input_nodes
和users
關系)。
可用的調試器
最常用的Python調試器是pdb。你可以通過在命令行輸入python -m pdb FILENAME.py
來以"調試模式"啟動程序,其中FILENAME
是你要調試的文件名。之后,你可以使用pdb
的調試器命令逐步執行正在運行的程序。通常的做法是在啟動pdb
時設置一個斷點(b LINE-NUMBER
),然后調用c
讓程序運行到該斷點處。這樣可以避免你不得不使用s
或n
逐行執行代碼才能到達想要檢查的部分。或者,你也可以在想中斷的代碼行前寫入import pdb; pdb.set_trace()
。如果添加了pdb.set_trace()
,當你運行程序時它會自動進入調試模式(換句話說,你只需在命令行輸入python FILENAME.py
而不用輸入python -m pdb FILENAME.py
)。一旦以調試模式運行文件,你就可以使用特定命令逐步執行代碼并檢查程序的內部狀態。網上有很多關于pdb
的優秀教程,包括RealPython的《Python Debugging With Pdb》。
像PyCharm或VSCode這樣的IDE通常內置了調試器。在你的IDE中,你可以選擇:a)通過調出IDE中的終端窗口(例如在VSCode中選擇View → Terminal)使用pdb
,或者b)使用內置的調試器(通常是pdb
的圖形化封裝)。
符號追蹤的局限性
FX 采用符號追蹤(又稱符號執行)系統,以可轉換/可分析的形式捕獲程序語義。該系統具有以下特點:
- 追蹤性:通過實際執行程序(實際是
torch.nn.Module
或函數)來記錄操作 - 符號性:執行過程中流經程序的數據并非真實數據,而是符號(FX術語中稱為
Proxy
)
雖然符號追蹤適用于大多數神經網絡代碼,但它仍存在一些局限性。
動態控制流
符號追蹤的主要局限在于目前不支持動態控制流。也就是說,當循環或if
語句的條件可能依賴于程序輸入值時,就無法處理。
例如,我們來看以下程序:
def func_to_trace(x):if x.sum() 0:return torch.relu(x)else:return torch.neg(x)traced = torch.fx.symbolic_trace(func_to_trace)
"""<...>File "dyn.py", line 6, in func_to_traceif x.sum() 0:File "pytorch/torch/fx/proxy.py", line 155, in __bool__return self.tracer.to_bool(self)File "pytorch/torch/fx/proxy.py", line 85, in to_boolraise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
if
語句的條件依賴于x.sum()
的值,而該值又依賴于函數輸入x
。由于x
可能發生變化(例如向追蹤函數傳入新的輸入張量時),這就形成了動態控制流。回溯信息會沿著代碼向上追溯,展示這種情況發生的位置。
靜態控制流
另一方面,系統支持所謂的靜態控制流。靜態控制流指的是那些在多次調用中值不會改變的循環或if
語句。通常在PyTorch程序中,這種控制流出現在根據超參數決定模型架構的代碼中。舉個具體例子:
import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self, do_activation : bool = False):super().__init__()self.do_activation = do_activationself.linear = torch.nn.Linear(512, 512)def forward(self, x):x = self.linear(x)# This if-statement is so-called static control flow.# Its condition does not depend on any input valuesif self.do_activation:x = torch.relu(x)return xwithout_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):linear_1 = self.linear(x); x = Nonereturn linear_1
"""traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):linear_1 = self.linear(x); x = Nonerelu_1 = torch.relu(linear_1); linear_1 = Nonereturn relu_1
"""
if self.do_activation
這個條件語句不依賴于任何函數輸入,因此它是靜態的。do_activation
可以被視為一個超參數,當 MyModule
的不同實例使用不同參數值時,生成的代碼軌跡也會不同。這是一種有效模式,符號追蹤功能支持這種模式。
許多動態控制流的實例在語義上其實是靜態控制流。通過消除對輸入值的數據依賴,這些實例可以支持符號追蹤。具體方法包括:
- 將值移至
Module
屬性中 - 在符號追蹤期間將具體值綁定到參數上
def f(x, flag):if flag: return xelse: return x*2fx.symbolic_trace(f) # Fails!fx.symbolic_trace(f, concrete_args={'flag': True})
在真正動態控制流的情況下,包含此類代碼的程序部分可以被追蹤為對方法的調用(參見使用Tracer類自定義追蹤)或函數調用(參見wrap()
),而不是直接追蹤這些代碼本身。
非torch
函數
FX采用__torch_function__
作為攔截調用的機制(更多技術細節請參閱技術概覽)。某些函數(如Python內置函數或math
模塊中的函數)不受__torch_function__
覆蓋,但我們仍希望在符號追蹤中捕獲它們。例如:
import torch
import torch.fx
from math import sqrtdef normalize(x):"""Normalize `x` by the size of the batch dimension"""return x / sqrt(len(x))# It's valid Python code
normalize(torch.rand(3, 4))traced = torch.fx.symbolic_trace(normalize)
"""<...>File "sqrt.py", line 9, in normalizereturn x / sqrt(len(x))File "pytorch/torch/fx/proxy.py", line 161, in __len__raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
錯誤提示表明內置函數 len
不被支持。
我們可以通過 wrap()
API 將此類函數記錄為跟蹤中的直接調用:
torch.fx.wrap('len')
torch.fx.wrap('sqrt')traced = torch.fx.symbolic_trace(normalize)print(traced.code)
"""
import math
def forward(self, x):len_1 = len(x)sqrt_1 = math.sqrt(len_1); len_1 = Nonetruediv = x / sqrt_1; x = sqrt_1 = Nonereturn truediv
"""
使用 Tracer
類自定義追蹤功能
Tracer
類是 symbolic_trace
功能的基礎實現類。通過繼承 Tracer 類,可以自定義追蹤行為,例如:
class MyCustomTracer(torch.fx.Tracer):# Inside here you can override various methods# to customize tracing. See the `Tracer` API# referencepass# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):def forward(self, x):return torch.relu(x) + torch.ones(3, 4)mod = MyModule()traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a # GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
葉子模塊
葉子模塊是指在符號追蹤過程中作為調用出現,而不會被繼續追蹤的模塊。默認的葉子模塊集合由標準torch.nn
模塊實例組成。例如:
class MySpecialSubmodule(torch.nn.Module):def forward(self, x):return torch.neg(x)class MyModule(torch.nn.Module):def __init__(self):super().__init__()self.linear = torch.nn.Linear(3, 4)self.submod = MySpecialSubmodule()def forward(self, x):return self.submod(self.linear(x))traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):linear_1 = self.linear(x); x = Noneneg_1 = torch.neg(linear_1); linear_1 = Nonereturn neg_1
"""
可以通過重寫 Tracer.is_leaf_module()
來自定義葉子模塊集合。
雜項說明
- 當前無法追蹤張量構造函數(如
torch.zeros
、torch.ones
、torch.rand
、torch.randn
、torch.sparse_coo_tensor
):- 確定性構造函數(
zeros
、ones
)仍可使用,其生成的值會作為常量嵌入追蹤記錄。僅當這些構造函數的參數涉及動態輸入大小時才會出現問題,此時可改用ones_like
或zeros_like
作為替代方案。 - 非確定性構造函數(
rand
、randn
)會將單個隨機值嵌入追蹤記錄,這通常不符合預期行為。變通方法是將torch.randn
包裝在torch.fx.wrap
函數中并調用該包裝函數。
- 確定性構造函數(
(注:保留所有代碼塊及技術術語原貌,被動語態轉為主動表述,長句拆分后保持技術嚴謹性)
@torch.fx.wrap
def torch_randn(x, shape):return torch.randn(shape)def f(x):return x + torch_randn(x, 5)
fx.symbolic_trace(f)
此行為可能在未來的版本中修復。
- 類型注解
-
支持 Python 3 風格的類型注解(例如
func(x : torch.Tensor, y : int) -torch.Tensor
),
并且會通過符號追蹤保留這些注解。 -
目前不支持 Python 2 風格的注釋類型注解
# type: (torch.Tensor, int) -torch.Tensor
。 -
目前不支持函數內部局部變量的類型注解。
- 關于
training
標志和子模塊的注意事項
- 當使用像
torch.nn.functional.dropout
這樣的函數時,通常會傳入self.training
作為訓練參數。在 FX 追蹤過程中,這個值很可能會被固定為一個常量。
import torch
import torch.fxclass DropoutRepro(torch.nn.Module):def forward(self, x):return torch.nn.functional.dropout(x, training=self.training)traced = torch.fx.symbolic_trace(DropoutRepro())
print(traced.code)
"""
def forward(self, x):dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = Nonereturn dropout
"""traced.eval()x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)
"""
AssertionError: Tensor-likes are not close!Mismatched elements: 15 / 15 (100.0%)
Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
"""
然而,當使用標準的 nn.Dropout()
子模塊時,訓練標志會被封裝起來,并且由于保留了 nn.Module
對象模型,可以對其進行修改。
class DropoutRepro2(torch.nn.Module):def __init__(self):super().__init__()self.drop = torch.nn.Dropout()def forward(self, x):return self.drop(x)traced = torch.fx.symbolic_trace(DropoutRepro2())
print(traced.code)
"""
def forward(self, x):drop = self.drop(x); x = Nonereturn drop
"""traced.eval()x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)
由于這一差異,建議將與動態training
標志交互的模塊標記為葉模塊。
API 參考
torch.fx.symbolic_trace(root, concrete_args=None)
符號追蹤 API
給定一個 nn.Module
或函數實例 root
,該 API 會返回一個 GraphModule
,這是通過記錄追蹤 root
時觀察到的操作構建而成的。
concrete_args
參數允許你對函數進行部分特化,無論是為了移除控制流還是數據結構。
例如:
def f(a, b):if b == True:return a else:return a * 2
由于控制流的存在,FX通常無法追蹤此過程。不過,我們可以使用concrete_args
來針對變量b的值進行特化處理,從而實現追蹤:
f = fx.symbolic_trace(f, concrete_args={"b": False})
assert f(3, False) == 6
請注意,雖然您仍可以傳入不同的b值,但這些值將被忽略。
我們還可以使用concrete_args
來消除函數中對數據結構的處理。這將利用pytrees來展平您的輸入。為了避免過度特化,對于不應特化的值,請傳入fx.PH
。例如:
def f(x):out = 0for v in x.values():out += vreturn outf = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert f({"a": 1, "b": 2, "c": 4}) == 7
參數
root (Union[torch.nn.Module, Callable])
- 待追蹤并轉換為圖表示形式的模塊或函數concrete_args (Optional[Dict[str, any]])
- 需要部分特化的輸入參數
返回從root
記錄的操作所創建的模塊。
返回類型:GraphModule
注意:此API保證向后兼容性。
torch.fx.wrap(fn_or_name)
該函數可在模塊級作用域調用,將fn_or_name
注冊為"葉子函數"。
"葉子函數"在FX跟蹤中會保留為CallFunction節點,而不會被進一步跟蹤。
# foo/bar/baz.py
def my_custom_function(x, y):return x * x + y * ytorch.fx.wrap("my_custom_function")def fn_to_be_traced(x, y):# When symbolic tracing, the below call to my_custom_function will be inserted into# the graph rather than tracing it.return my_custom_function(x, y)
該函數也可以等效地用作裝飾器:
# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):return x * x + y * y
包裝函數可以被視為"葉子函數",類似于"葉子模塊"的概念,也就是說,這些函數在FX跟蹤中會保留為調用點,而不會被進一步追蹤。
參數
fn_or_name (Union[str, Callable])
- 當被調用時,要插入到圖中的函數或全局函數名稱
注意:此API保證向后兼容性。
class torch.fx.GraphModule(*args, **kwargs)
GraphModule 是由 fx.Graph 生成的 nn.Module。GraphModule 具有一個 graph
屬性,以及從該 graph
生成的 code
和 forward
屬性。
警告:當重新分配 graph
時,code
和 forward
將自動重新生成。但如果你編輯了 graph
的內容而沒有重新分配 graph
屬性本身,則必須調用 recompile()
來更新生成的代碼。
注意:此 API 保證向后兼容性。
__init__(root, graph, class_name='GraphModule')
構建一個 GraphModule。
參數
root (Union[torch.nn.Module , Dict[str, Any])
–root
可以是 nn.Module 實例,也可以是將字符串映射到任意屬性類型的字典。
當 root
是 Module 時,Graph 的 Nodes 中 target
字段對基于 Module 的對象(通過限定名稱引用)的任何引用,都會從 root
的 Module 層次結構中的相應位置復制到 GraphModule 的模塊層次結構中。
當 root
是字典時,Node 的 target
中找到的限定名稱將直接在字典的鍵中查找。字典映射的對象將被復制到 GraphModule 模塊層次結構中的適當位置。
graph (Graph)
–graph
包含此 GraphModule 用于代碼生成的節點class_name (str)
–name
表示此 GraphModule 的名稱,用于調試目的。如果未設置,所有錯誤消息將報告為源自GraphModule
。將其設置為root
的原始名稱或在轉換上下文中合理的名稱可能會有所幫助。
注意:此 API 保證向后兼容性。
add_submodule(target, m)
將給定的子模塊添加到self
中。
如果target
是子路徑且對應位置尚未存在模塊,此方法會安裝空的模塊。
參數
target (str)
- 新子模塊的完整限定字符串名稱
(參見nn.Module.get_submodule
中的示例了解如何指定完整限定字符串)m (Module)
- 子模塊本身;即我們想要安裝到當前模塊中的實際對象
返回
子模塊是否能夠被插入。要使該方法返回True,target
表示的鏈中每個對象必須滿足以下條件之一:
a) 尚不存在,或
b) 引用的是nn.Module
(而非參數或其他屬性)
返回類型:bool
注意:此API保證向后兼容性。
property code: str
返回從該 GraphModule
底層 Graph
生成的 Python 代碼。
delete_all_unused_submodules()
***
Deletes all unused submodules from `self`.A Module is considered “used” if any one of the following is true:
1、It has children that are used
2、Its forward is called directly via a `call_module` node
3、It has a non-Module attribute that is used from a `get_attr` nodeThis method can be called to clean up an `nn.Module` without
manually calling `delete_submodule` on each unused submodule.
***
Note: Backwards-compatibility for this API is guaranteed.delete_submodule(target)
從self
中刪除指定的子模塊。
如果target
不是有效的目標,則不會刪除該模塊。
參數
target (str)
- 新子模塊的完全限定字符串名稱
(有關如何指定完全限定字符串的示例,請參閱nn.Module.get_submodule
)
返回值
表示目標字符串是否引用了我們要刪除的子模塊。返回值為False
意味著target
不是有效的子模塊引用。
返回類型 : bool
注意:此API保證向后兼容性。
property graph: [Graph](https://pytorch.org/docs/stable/data.html#torch.fx.Graph "torch.fx.graph.Graph")
返回該 GraphModule
底層對應的 Graph
print_readable(print_output=True, include_stride=False, include_device=False, colored=False)
返回為當前 GraphModule 及其子 GraphModule 生成的 Python 代碼
警告:此 API 為實驗性質,且不保證向后兼容性。
recompile()
根據其 graph
屬性重新編譯該 GraphModule。在編輯包含的 graph
后應調用此方法,否則該 GraphModule
生成的代碼將過期。
注意:此 API 保證向后兼容性。
返回類型:PythonCode
to_folder(folder, module_name='FxModule')
將模塊以 module_name
名稱轉儲到 folder
目錄下,以便可以通過 from <folder> import <module_name>
方式導入。
參數:
folder (Union [str, os.PathLike])
: 用于輸出代碼的目標文件夾路徑
module_name (str): 在輸出代碼時使用的頂層模塊名稱
警告:此 API 為實驗性質,不保證向后兼容性。
class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)
Graph
是 FX 中間表示中使用的主要數據結構。
它由一系列 Node
組成,每個節點代表調用點(或其他語法結構)。這些 Node
的集合共同構成了一個有效的 Python 函數。
例如,以下代碼
import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self):super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)m = MyModule()
gm = torch.fx.symbolic_trace(m)
將生成以下圖表:
print(gm.graph)
graph(x):%linear_weight : [num_users=1] = self.linear.weight%add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})%linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})%relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})%sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})%topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})return topk_1
關于Graph
中操作的具體語義,請參閱Node
文檔。
注意:本API保證向后兼容性。
__init__(owning_module=None, tracer_cls=None, tracer_extras=None)
構建一個空圖。
注意:此 API 保證向后兼容性。
call_function(the_function, args=None, kwargs=None, type_expr=None)
在Graph
中插入一個call_function
類型的Node
。call_function
節點表示對Python可調用對象的調用,由the_function
指定。
參數
the_function (Callable[...*, Any])
– 要調用的函數。可以是任何PyTorch運算符、Python函數,或屬于builtins
或operator
命名空間的成員。args (Optional[Tuple[Argument*, ...]])
– 傳遞給被調用函數的位置參數。kwargs (Optional[Dict[str, Argument]])
– 傳遞給被調用函數的關鍵字參數。type_expr (Optional[Any])
– 可選的類型注解,表示該節點輸出值的Python類型。
返回
新創建并插入的call_function
節點。
返回類型
Node
注意:此方法的插入點和類型表達式規則與Graph.create_node()
相同。
注意:此API保證向后兼容性。
call_method(method_name, args=None, kwargs=None, type_expr=None)
向Graph
中插入一個call_method
節點。call_method
節點表示對args
第0個元素調用指定方法。
參數
method_name (str)
- 要應用于self參數的方法名稱。例如,如果args[0]是一個表示Tensor
的Node
,那么要對該Tensor
調用relu()
方法時,需將relu
作為method_name
傳入。args (Optional[Tuple[Argument*, ...]])
- 要傳遞給被調用方法的位置參數。注意這應該包含一個self參數。kwargs (Optional[Dict[str, Argument]])
- 要傳遞給被調用方法的關鍵字參數type_expr (Optional[Any])
- 可選的類型注解,表示該節點輸出結果的Python類型。
返回
新創建并插入的call_method
節點。
返回類型
Node
注意:本方法的插入點和類型表達式規則與Graph.create_node()
相同。
注意:此API保證向后兼容性。
call_module(module_name, args=None, kwargs=None, type_expr=None)
向Graph
中插入一個call_module
類型的Node
節點。call_module
節點表示對Module
層級結構中某個Module
的forward()函數的調用。
參數
module_name (str)
- 要調用的Module
在層級結構中的限定名稱。例如,若被追蹤的Module
有一個名為foo
的子模塊,而該子模塊又包含名為bar
的子模塊,則應以foo.bar
作為module_name
來調用該模塊。args (Optional[Tuple[Argument*, ...]])
- 傳遞給被調用方法的位置參數。注意:此處不應包含self
參數。kwargs (Optional[Dict[str, Argument]])
- 傳遞給被調用方法的關鍵字參數type_expr (Optional[Any])
- 可選類型注解,表示該節點輸出值的Python類型。
返回
新創建并插入的call_module
節點。
返回類型:Node
注意:本方法的插入點與類型表達式規則與Graph.create_node()
相同。
注意:本API保證向后兼容性。
create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)
創建一個 Node
并將其添加到當前插入點的 Graph
中。
注意:當前插入點可以通過 Graph.inserting_before()
和 Graph.inserting_after()
進行設置。
參數
op (str)
- 該節點的操作碼。可選值包括 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’。這些操作碼的語義在Graph
的文檔字符串中有詳細說明。args (Optional[Tuple[Argument*, ...]])
- 該節點的參數元組。kwargs (Optional[Dict[str, Argument]])
- 該節點的關鍵字參數。name (Optional[str])
- 為Node
指定的可選字符串名稱。這將影響生成的 Python 代碼中賦值給該節點的變量名。type_expr (Optional[Any])
- 可選類型注解,表示該節點輸出值的 Python 類型。
返回
新創建并插入的節點。
返回類型:Node
注意:此 API 保證向后兼容。
eliminate_dead_code(is_impure_node=None)
根據圖中各節點的用戶數量及是否具有副作用,移除所有死代碼。調用前必須確保圖已完成拓撲排序。
參數
is_impure_node (Optional[Callable[[Node],* [bool]]])
—— 用于判斷節點是否為非純函數的回調函數。若未提供該參數,則默認使用Node.is_impure
方法。
返回值:返回布爾值,表示該過程是否導致圖結構發生變更。
返回類型:bool
示例
在消除死代碼前,下方表達式 a = x + 1
中的變量 a
無用戶引用,因此可從圖中安全移除而不影響結果。
def forward(self, x):a = x + 1return x + self.attr_1
消除死代碼后,a = x + 1
已被移除,前向傳播部分的其他代碼保留不變。
def forward(self, x):return x + self.attr_1
警告:死代碼消除機制雖然采用了一些啟發式方法來避免刪除具有副作用的節點(參見 Node.is_impure
),但總體覆蓋率非常不理想。因此,除非你明確知道當前 FX 計算圖完全由無副作用的操作構成,或者自行提供了檢測副作用節點的自定義函數,否則不應假設調用此方法是安全可靠的。
注意:本 API 保證向后兼容性。
erase_node(to_erase)
從Graph
中刪除一個Node
。如果該節點在Graph
中仍被使用,將拋出異常。
參數
to_erase (Node)
– 要從Graph
中刪除的Node
。
注意:此API保證向后兼容性。
find_nodes(*, op, target=None, sort=True)
支持快速查詢節點
參數
op (str)
– 操作名稱target (Optional[Target])
– 節點目標。對于call_function操作,target為必填項;其他操作中target為可選參數。sort ([bool])
– 是否按節點在圖中出現的順序返回結果。
返回值:返回符合指定op和target條件的節點迭代器。
警告:此API為實驗性質,且不保證向后兼容。
get_attr(qualified_name, type_expr=None)
向圖中插入一個 get_attr
節點。get_attr
類型的 Node
表示從 Module
層次結構中獲取某個屬性。
參數
qualified_name (str)
- 要獲取屬性的全限定名稱。例如,若被追蹤的 Module 包含名為foo
的子模塊,該子模塊又包含名為bar
的子模塊,而bar
擁有名為baz
的屬性,則應將全限定名稱foo.bar.baz
作為qualified_name
傳入。type_expr (Optional[Any])
- 可選的類型注解,用于表示該節點輸出值的 Python 類型。
返回
新創建并插入的 get_attr
節點。
返回類型:Node
注意:本方法的插入點與類型表達式規則與 Graph.create_node
方法保持一致。
注意:此 API 保證向后兼容性。
graph_copy(g, val_map, return_output_node=False)
將給定圖中的所有節點復制到 self
中。
參數
g (Graph)
– 作為節點復制來源的原始圖。val_map (Dict[Node,* Node])
– 用于存儲節點映射關系的字典,鍵為g
中的節點,值為self
中的對應節點。注意:val_map
可預先包含值以實現特定值的復制覆蓋。
返回值:如果 g
包含輸出節點,則返回 self
中與 g
輸出值等效的值;否則返回 None
。
返回類型:Optional[Union [tuple [Argument, …], Sequence [Argument], Mapping [str , Argument], slice , range , Node, str , int , float, bool , complex , [dtype](tensor_attributes.html#torch.dtype "torch.dtype"), Tensor , device , memory_format , layout , OpOverload, [SymInt](torch.html#torch.SymInt "torch.SymInt"), SymBool , SymFloat ]]
注意:本API保證向后兼容性。
inserting_after(n=None)
設置 create_node
及相關方法在圖中插入節點的位置。當在 with
語句中使用時,這會臨時設置插入點,并在 with
語句退出時恢復原位置。
with g.inserting_after(n):... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) # set the insert point permanently
參數:
n (可選[Node]): 要在其之前插入的節點。如果為None,則會在整個圖的起始位置之后插入。
返回:
一個資源管理器,它會在__exit__
時恢復插入點。
注意:此API保證向后兼容性。
inserting_before(n=None)
設置 create_node
及相關方法在圖中插入節點的基準位置。當在 with
語句中使用時,這將臨時設置插入點,并在 with
語句退出時恢復原位置。
with g.inserting_before(n):... # inserting before node n
... # insert point restored to what it was previously
g.inserting_before(n) # set the insert point permanently
參數:
n (Optional[Node]): 要插入位置的前一個節點。如果為None,則會在整個圖的起始位置前插入。
返回:
一個資源管理器,該管理器會在__exit__
時恢復插入點。
注意:此API保證向后兼容性。
lint()
對該圖執行多項檢查以確保其結構正確。具體包括:
- 檢查節點是否具有正確的所有權(由本圖所有)
- 檢查節點是否按拓撲順序排列
- 若該圖擁有所屬的GraphModule,則檢查目標是否存在該GraphModule中
注:本API保證向后兼容性。
node_copy(node, arg_transform=<function Graph.<lambda>>)
將節點從一個圖復制到另一個圖中。arg_transform
需要將節點所在圖的參數轉換為目標圖(self)的參數。示例:
# Copying all the nodes in `g` into `new_graph`
g: torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}for node in g.nodes:value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
參數
node (Node)
– 要復制到self
中的節點。arg_transform (Callable[[Node], Argument])
– 一個函數,用于將節點args
和kwargs
中的Node
參數轉換為self
中的等效參數。最簡單的情況下,該函數應從原始圖中節點到self
的映射表中檢索值。
返回類型:Node
注意:此 API 保證向后兼容性。
property nodes: _node_list
獲取構成該圖的所有節點列表。
請注意,這個Node
列表是以雙向鏈表的形式表示的。在迭代過程中進行修改(例如刪除節點、添加節點)是安全的。
返回值:一個雙向鏈表結構的節點列表。注意可以對該列表調用reversed
方法來切換迭代順序。
on_generate_code(make_transformer)
在生成 Python 代碼時注冊轉換器函數
參數:
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):返回待注冊代碼轉換器的函數。
該函數由 on_generate_code 調用以獲取代碼轉換器。
此函數的輸入參數為當前已注冊的代碼轉換器(若未注冊則為 None),以便在不需要覆蓋時使用。該機制可用于串聯多個代碼轉換器。
返回值:一個上下文管理器,當在 with 語句中使用時,會自動恢復先前注冊的代碼轉換器。
示例:
gm: fx.GraphModule = ...# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):return ["import pdb; pdb.set_trace()\n", body]# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(lambda _: insert_pdb)# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(lambda current_trans: (lambda body: insert_pdb(current_trans(body) if current_trans else body))
)gm.recompile()
gm(inputs) # drops into pdb
該函數也可作為上下文管理器使用,其優勢在于能自動恢復之前注冊的代碼轉換器。
# ... continue from previous examplewith gm.graph.on_generate_code(lambda _: insert_pdb):# do more stuff with `gm`...gm.recompile()gm(inputs) # drops into pdb# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告:此 API 為實驗性質,且不向后兼容。
output(result, type_expr=None)
將 output
Node
插入到 Graph
中。output
節點代表 Python 代碼中的 return
語句。result
是應當返回的值。
參數
result (Argument)
– 要返回的值。type_expr (Optional[Any])
– 可選的類型注解,表示此節點輸出將具有的 Python 類型。
注意:此方法的插入點和類型表達式規則與 Graph.create_node
相同。
注意:此 API 保證向后兼容性。
output_node()
警告:此 API 為實驗性質,且不向后兼容。
返回值類型:Node
placeholder(name, type_expr=None, default_value)
在圖中插入一個placeholder
節點。placeholder
表示函數的輸入參數。
參數
name (str)
- 輸入值的名稱。這對應于該Graph
所表示函數的位置參數名稱。type_expr (Optional[Any])
- 可選的類型注解,表示該節點輸出值的Python類型。在某些情況下(例如當函數后續用于TorchScript編譯時),這是生成正確代碼所必需的。default_value (Any)
- 該函數參數的默認值。注意:為了允許None作為默認值,當參數沒有默認值時,應傳遞inspect.Signature.empty來指定。
返回類型:Node
注意:此方法的插入點和類型表達式規則與Graph.create_node
相同。
注意:此API保證向后兼容性。
print_tabular()
以表格形式打印圖的中間表示。注意:此API需要安裝tabulate
模塊。
注:該API保證向后兼容性。
process_inputs(*args)
處理參數以便它們可以傳遞到 FX 計算圖中。
警告:此 API 為實驗性質,且不向后兼容。
process_outputs(out)
警告:此 API 為實驗性質,且不向后兼容。
python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)
將這段Graph
轉換為有效的Python代碼。
參數
root_module (str)
– 用于查找限定名稱目標的根模塊名稱。通常為’self’。
返回值:src: 表示該對象的Python源代碼
globals: 包含src中全局名稱及其引用對象的字典
返回類型:一個包含兩個字段的PythonCode對象
注意:此API保證向后兼容性。
set_codegen(codegen)
警告:此 API 為實驗性功能,且不向后兼容。
class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)
Node
是表示 Graph
中單個操作的數據結構。在大多數情況下,Node 表示對各種實體的調用點,例如運算符、方法和模塊(某些例外包括指定函數輸入和輸出的節點)。每個 Node
都有一個由其 op
屬性指定的函數。不同 op
值的 Node
語義如下:
placeholder
表示函數輸入。name
屬性指定該值的名稱。target
同樣是參數的名稱。args
包含:1) 空值,或 2) 表示函數輸入默認參數的單個參數。kwargs
無關緊要。占位符對應于圖形輸出中的函數參數(例如x
)。get_attr
從模塊層次結構中檢索參數。name
同樣是獲取結果后賦值的名稱。target
是參數在模塊層次結構中的完全限定名稱。args
和kwargs
無關緊要。call_function
將自由函數應用于某些值。name
同樣是賦值目標的名稱。target
是要應用的函數。args
和kwargs
表示函數的參數,遵循 Python 調用約定。call_module
將模塊層次結構中的forward()
方法應用于給定參數。name
同前。target
是要調用的模塊在模塊層次結構中的完全限定名稱。args
和kwargs
表示調用模塊時的參數(不包括 self 參數*)。call_method
調用值的方法。name
類似。target
是要應用于self
參數的方法名稱字符串。args
和kwargs
表示調用模塊時的參數(包括 self 參數*)。output
在其args[0]
屬性中包含跟蹤函數的輸出。這對應于圖形輸出中的 “return” 語句。
注意:此 API 保證向后兼容。
property all_input_nodes: list ['Node']
Return all Nodes that are inputs to this Node. This is equivalent to iterating over `args` and `kwargs` and only collecting the values that are Nodes.Returns
List of `Nodes` that appear in the `args` and `kwargs` of this `Node`, in that order.append(x)
在圖的節點列表中,將 x
插入到當前節點之后。
等價于調用 self.next.prepend(x)
參數
x (Node)
– 要插入到當前節點后的節點。必須屬于同一個圖。
注意:此 API 保證向后兼容。
property args: tuple [Union [tuple ['Argument',
...], collections.abc.Sequence ['Argument'], collections.abc.Mapping[str , 'Argument'], slice , range , torch.fx.node.Node, str , int , float, bool , complex , torch.dtype , torch.Tensor, torch.device , torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType],
...]
該Node
的參數元組。參數的具體含義取決于節點的操作碼(opcode)。更多信息請參閱Node
文檔字符串。
允許對此屬性進行賦值操作。所有關于使用情況和用戶的記錄都會在賦值時自動更新。
format_node(placeholder_names=None, maybe_return_typename=None)
返回一個描述性的字符串表示形式self
。
該方法可不帶參數使用,作為調試工具。
此函數也用于Graph
的__str__
方法內部。placeholder_names
和maybe_return_typename
中的字符串共同構成了該Graph所屬GraphModule中自動生成的forward
函數的簽名。placeholder_names
和maybe_return_typename
不應在其他情況下使用。
參數
placeholder_names (Optional[list[str]])
- 一個列表,用于存儲表示生成的forward
函數中占位符的格式化字符串。僅供內部使用。maybe_return_typename (Optional[list[str]])
- 一個單元素列表,用于存儲表示生成的forward
函數輸出的格式化字符串。僅供內部使用。
返回
如果1)我們在Graph
的__str__
方法中將format_node
用作內部輔助工具,且2)self
是一個占位符Node,則返回None
。否則,返回當前Node的描述性字符串表示形式。
返回類型:str
注意:此API保證向后兼容。
insert_arg(idx, arg)
在參數列表的指定索引位置插入一個位置參數。
參數
idx ( int )
– 要插入到self.args
中元素之前的索引位置。arg (Argument)
– 要插入到args
中的新參數值
注意:本API保證向后兼容性。
is_impure()
返回該操作是否為不純操作,即判斷其操作是否為占位符或輸出,或者是否為不純的call_function
或call_module
。
返回值:指示該操作是否不純。
返回類型:bool
警告:此API為實驗性質,且不向后兼容。
property kwargs: dict[str , Union [tuple ['Argument',
...], collections.abc.Sequence['Argument'], collections.abc.Mapping, [str , 'Argument'], slice , range , torch.fx.node.Node, str , int , float, bool , complex , torch.dtype , torch.Tensor, torch.device , torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]
該Node
的關鍵字參數字典。參數的解析取決于節點的操作碼。更多信息請參閱Node
文檔字符串。
允許對此屬性進行賦值。所有關于使用情況和用戶的統計都會在賦值時自動更新。
property next: Node
返回鏈表中下一個Node
節點。
返回值:鏈表中下一個Node
節點。
normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)
返回經過標準化的Python目標參數。這意味著當normalize_to_only_use_kwargs
為真時,args/kwargs將與模塊/函數的簽名匹配,并按位置順序僅返回kwargs。
同時會填充默認值。不支持僅限位置參數或可變參數。
支持模塊調用。
可能需要arg_types
和kwarg_types
來消除重載歧義。
參數
root (torch.nn.Module)
– 用于解析模塊目標的基模塊arg_types (Optional[Tuple[Any]])
– 參數的元組類型kwarg_types (Optional[Dict[str, Any]])
– 關鍵字參數的字典類型normalize_to_only_use_kwargs ([bool])
– 是否標準化為僅使用kwargs
返回
返回命名元組ArgsKwargsPair
,若失敗則返回None
返回類型
Optional[ArgsKwargsPair]
警告:該API為實驗性質,不保證向后兼容。
prepend(x)
在圖的節點列表中,在此節點前插入x。示例:
Before: p -selfbx -x -ax
After: p -x -selfbx -ax
參數
x (Node)
– 要放置在該節點之前的節點。必須是同一圖的成員。
注意:此 API 保證向后兼容。
property prev: Node
返回鏈表中當前節點的前一個Node
。
返回值:鏈表中當前節點的前一個Node
。
replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)
將圖中所有使用 self
的地方替換為節點 replace_with
。
參數
replace_with (Node)
– 用于替換所有self
的節點。delete_user_cb (Callable)
– 回調函數,用于判斷是否應移除某個使用原self
節點的用戶節點。propagate_meta ([bool])
– 是否將原節點.meta
字段的所有屬性復制到替換節點上。出于安全考慮,僅當替換節點本身沒有.meta
字段時才允許此操作。
返回值
返回受此變更影響的節點列表。
返回類型:list [Node]
注意:此 API 保證向后兼容。
replace_input_with(old_input, new_input)
遍歷 self
的輸入節點,將所有 old_input
實例替換為 new_input
。
參數
old_input (Node)
– 需要被替換的舊輸入節點。new_input (Node)
– 用于替換old_input
的新輸入節點。
注意:此 API 保證向后兼容性。
property stack_trace: Optional[str ]
返回在追蹤過程中記錄的 Python 堆棧跟蹤信息(如果有)。
當使用 fx.Tracer
進行追蹤時,該屬性通常由 Tracer.create_proxy
填充。若需在追蹤過程中記錄堆棧跟蹤以用于調試,請在 Tracer 實例上設置 record_stack_traces = True
。
當使用 dynamo 進行追蹤時,該屬性默認會由 OutputGraph.create_proxy
填充。
stack_trace
的字符串末尾將包含最內層的調用幀。
update_arg(idx, arg)
更新現有位置參數以包含新值
調用后,self.args[idx] == arg
將成立。
參數
idx ( int )
- 要更新元素在self.args
中的索引位置arg (Argument)
- 要寫入args
的新參數值
注意:此 API 保證向后兼容性。
update_kwarg(key, arg)
更新現有關鍵字參數以包含新值
arg
。調用后,self.kwargs[key] == arg
。
參數
key (str)
- 要更新的元素在self.kwargs
中的鍵名arg (Argument)
- 要寫入kwargs
的新參數值
注意:此API保證向后兼容性。
class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())
Tracer
是實現 torch.fx.symbolic_trace
符號追蹤功能的類。調用 symbolic_trace(m)
等價于執行 Tracer().trace(m)
。
可以通過繼承 Tracer 類來覆蓋追蹤過程中的各種行為。具體可覆蓋的行為詳見該類方法的文檔字符串。
注意:此 API 保證向后兼容。
call_module(m, forward, args, kwargs)
該方法定義了當Tracer
遇到對nn.Module
實例調用時的行為。
默認行為是通過is_leaf_module
檢查被調用的模塊是否為葉子模塊。如果是,則在Graph
中生成指向m
的call_module
節點;否則正常調用該Module
,并跟蹤其forward
函數中的操作。
可通過重寫此方法實現自定義行為,例如:
- 創建嵌套的追蹤GraphModules
- 實現跨
Module
邊界追蹤時的特殊處理
參數說明:
m (Module)
- 當前被調用的模塊實例forward (Callable)
- 待調用模塊的forward()方法args (Tuple)
- 模塊調用點的參數元組kwargs (Dict)
- 模塊調用點的關鍵字參數字典
返回值:
- 若生成
call_module
節點,則返回Proxy
代理值 - 否則返回模塊調用的原始結果
返回類型:任意類型
注意:本API保證向后兼容性。
create_arg(a)
一種方法,用于指定在準備值作為Graph
中節點的參數時追蹤的行為。
默認行為包括:
1、遍歷集合類型(如元組、列表、字典)并遞歸地對元素調用create_args
。
2、給定一個Proxy對象,返回底層IR Node
的引用。
3、給定一個非Proxy的Tensor對象,為以下情況生成IR:
-
對于Parameter,生成一個引用該Parameter的
get_attr
節點。 -
對于非Parameter的Tensor,將該Tensor存儲在一個特殊屬性中,并引用該屬性。
可以重寫此方法以支持更多類型。
參數
a (Any)
– 將被作為Argument
在Graph
中使用的值。
返回值:將值a
轉換為適當的Argument
。
返回類型:Argument
注意:此API保證向后兼容。
create_args_for_root(root_fn, is_module, concrete_args=None)
為root
模塊的簽名創建對應的placeholder
節點。該方法會檢查root模塊的簽名并據此生成這些節點,同時支持*args
和**kwargs
參數。
警告:此API為實驗性質,且不向后兼容。
create_node(kind, target, args, kwargs, name=None, type_expr=None)
根據給定的目標、參數、關鍵字參數和名稱插入一個圖節點。
該方法可以被重寫,用于在節點創建過程中對使用的值進行額外檢查、驗證或修改。例如,可能希望禁止記錄原地操作。
注意:此API保證向后兼容性。
返回類型:Node
create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)
根據給定的參數創建一個節點,然后返回包裹在 Proxy 對象中的節點。
如果 kind = ‘placeholder’,則表示我們正在創建一個代表函數參數的節點。若需要編碼默認參數,則使用 args
元組。對于 placeholder
類型的節點,args
在其他情況下為空。
注意:此 API 保證向后兼容性。
get_fresh_qualname(prefix)
獲取一個基于前綴的新名稱并返回。該函數確保生成的名稱不會與圖中現有屬性發生沖突。
注意:此API保證向后兼容。
返回類型:str
getattr(attr, attr_val, parameter_proxy_cache)
該方法定義了當對nn.Module
實例調用getattr時,該Tracer
的行為表現。
默認情況下,其行為是返回該屬性的代理值。同時會將代理值存入parameter_proxy_cache
中,以便后續調用能復用該代理而非新建。
可通過重寫此方法來實現不同行為——例如在查詢參數時不返回代理。
參數說明:
attr (str)
- 被查詢的屬性名稱attr_val (Any)
- 該屬性的值parameter_proxy_cache (Dict[str, Any])
- 屬性名到代理值的映射緩存
返回值:
getattr調用的返回結果。
警告:此API屬于實驗性質,且不保證向后兼容。
is_leaf_module(m, module_qualified_name)
一種用于判斷給定nn.Module
是否為"葉子"模塊的方法。
葉子模塊是指出現在IR(中間表示)中的原子單元,通過call_module
調用進行引用。默認情況下,PyTorch標準庫命名空間(torch.nn)中的模塊都屬于葉子模塊。除非通過本參數特別指定,否則其他所有模塊都會被追蹤并記錄其組成操作。
參數說明:
m (Module)
- 被查詢的模塊module_qualified_name (str)
- 該模塊到根模塊的路徑。例如,若模塊層級結構中子模塊foo
包含子模塊bar
,而bar
又包含子模塊baz
,則該模塊的限定名將顯示為foo.bar.baz
返回類型:bool
注意:本API保證向后兼容性。
iter(obj)
當代理對象被迭代時調用,例如在控制流中使用時。通常我們不知道如何處理,因為我們不知道代理的值,但自定義跟蹤器可以通過 create_node
向圖節點附加更多信息,并可以選擇返回一個迭代器。
注意:此 API 保證向后兼容性。
返回類型:迭代器
keys(obj)
當代理對象的 keys()
方法被調用時觸發。這是在代理對象上調用 **
時發生的情況。該方法應返回一個迭代器,如果 **
需要在自定義追蹤器中生效。
注意:此 API 保證向后兼容。
返回類型:任意
path_of_module(mod)
這是一個輔助方法,用于在root
模塊的層級結構中查找mod
的限定名稱。例如,如果root
有一個名為foo
的子模塊,而foo
又有一個名為bar
的子模塊,那么將bar
傳入此函數將返回字符串"foo.bar"。
參數
mod (str)
– 需要獲取限定名稱的Module
。
返回類型:str
注意:此API保證向后兼容性。
proxy(node)
注意:此 API 保證向后兼容性。
返回類型:Proxy
to_bool(obj)
當代理對象需要轉換為布爾值時調用,例如在控制流中使用時。通常我們無法確定如何處理,因為不知道代理的具體值,但自定義追蹤器可以通過create_node
向圖節點附加更多信息,并選擇返回一個值。
注意:此API保證向后兼容。
返回類型:bool
trace(root, concrete_args=None)
追蹤 root
并返回對應的 FX Graph
表示形式。root
可以是 nn.Module
實例或 Python 可調用對象。
請注意,在此調用后,self.root
可能與傳入的 root
不同。例如,當向 trace()
傳遞自由函數時,我們會創建一個 nn.Module
實例作為根節點,并添加嵌入的常量。
參數
root (Union[Module, Callable])
– 需要追蹤的Module
或函數。該參數保證向后兼容性。concrete_args (Optional[Dict[str, any]])
– 不應被視為代理的具體參數。此參數為實驗性功能,其向后兼容性不作保證。
返回值:表示傳入 root
語義的 Graph
對象。
返回類型:Graph
注意:此 API 保證向后兼容性。
class torch.fx.Proxy(node, tracer=None)
Proxy
對象是Node
包裝器,在符號追蹤過程中流經程序,并記錄它們接觸到的所有操作(包括torch
函數調用、方法調用和運算符)到不斷增長的FX Graph中。
如果需要進行圖變換,您可以在原始Node
上封裝自己的Proxy
方法,這樣就可以使用重載運算符向Graph
添加額外內容。
Proxy
對象不可迭代。換句話說,如果在循環中或作為*args
/**kwargs
函數參數使用Proxy
,符號追蹤器會拋出錯誤。
有兩種主要解決方法:
1、將不可追蹤的邏輯提取到頂層函數中,并使用fx.wrap
進行處理。
2、如果控制流是靜態的(即循環次數基于某些超參數),可以保持代碼在原位,并重構為類似形式:
for i in range(self.some_hyperparameter):indexed_item = proxied_value[i]
如需更深入了解 Proxy 的內部實現細節,請查閱 torch/fx/README.md 文件中的 “Proxy” 章節。
注意:本 API 保證向后兼容性。
class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)
解釋器(Interpreter)會逐節點(Node-by-Node)執行FX圖。這種模式在許多場景下非常有用,包括編寫代碼轉換器以及分析過程。
通過重寫Interpreter類中的方法,可以自定義執行行為。以下是按調用層次結構劃分的可重寫方法映射:
run()+-- run_node+-- placeholder()+-- get_attr()+-- call_function()+-- call_method()+-- call_module()+-- output()
示例
假設我們需要將所有 torch.neg
實例與 torch.sigmoid
互換(包括它們對應的 Tensor
方法等價形式)。我們可以通過如下方式繼承 Interpreter 類:
class NegSigmSwapInterpreter(Interpreter):def call_function(self, target: Target, args: Tuple, kwargs: Dict) -Any:if target == torch.sigmoid:return torch.neg(args, *kwargs)return super().call_function(target, args, kwargs)def call_method(self, target: Target, args: Tuple, kwargs: Dict) -Any:if target == "neg":call_self, args_tail = argsreturn call_self.sigmoid(args_tail, *kwargs)return super().call_method(target, args, kwargs)def fn(x):return torch.sigmoid(x).neg()gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
參數
module ( torch.nn.Module )
– 待執行的模塊garbage_collect_values ([bool])
– 是否在模塊執行過程中最后一次使用后刪除值。這能確保執行期間內存使用最優。可以禁用此功能,例如通過查看Interpreter.env
屬性來檢查執行中的所有中間值。graph (Optional[Graph])
– 如果傳入該參數,解釋器將執行此圖而非module.graph,并使用提供的模塊參數來滿足任何狀態請求。
注意:此API保證向后兼容性。
boxed_run(args_list)
通過解釋方式運行模塊并返回結果。該過程采用"boxed"調用約定,即傳遞一個參數列表(這些參數會被解釋器自動清除),從而確保輸入張量能夠及時釋放。
注意:本API保證向后兼容性。
call_function(target, args, kwargs)
執行一個call_function
節點并返回結果。
參數
target (Target)
– 該節點的調用目標。關于語義的詳細信息請參閱Nodeargs (Tuple)
– 本次調用的位置參數元組kwargs (Dict)
– 本次調用的關鍵字參數字典
返回類型:任意類型
返回值: 函數調用返回的值
注意:此API保證向后兼容性。
call_method(target, args, kwargs)
執行一個 call_method
節點并返回結果。
參數
target (Target)
– 該節點的調用目標。有關語義的詳細信息,請參閱 Nodeargs (Tuple)
– 該調用的位置參數元組kwargs (Dict)
– 該調用的關鍵字參數字典
返回類型:任意
返回值:方法調用返回的值
注意:此 API 保證向后兼容性。
call_module(target, args, kwargs)
執行一個call_module
節點并返回結果。
參數
target (Target)
– 該節點的調用目標。關于語義的詳細信息請參閱
Nodeargs (Tuple)
– 本次調用的位置參數元組kwargs (Dict)
– 本次調用的關鍵字參數字典
返回類型:Any
返回值:模塊調用返回的值
注意:此API保證向后兼容性。
fetch_args_kwargs_from_env(n)
從當前執行環境中獲取節點n
的args
和kwargs
具體值
參數
n (Node)
– 需要獲取args
和kwargs
的目標節點
返回值
節點n
對應的具體args
和kwargs
值
返回類型:Tuple[Tuple, Dict]
注意:本API保證向后兼容性
fetch_attr(target)
從 self.module
的 Module
層級結構中獲取一個屬性。
參數
target (str)
- 要獲取屬性的全限定名稱
返回
該屬性的值。
返回類型
任意類型
注意:此 API 保證向后兼容。
get_attr(target, args, kwargs)
執行一個 get_attr
節點。該操作會從 self.module
的 Module
層級結構中獲取屬性值。
參數
target (Target)
– 該節點的調用目標。關于語義的詳細信息請參閱 Nodeargs (Tuple)
– 本次調用的位置參數元組kwargs (Dict)
– 本次調用的關鍵字參數字典
返回值
獲取到的屬性值
返回類型
任意類型
注意:此 API 保證向后兼容性。
map_nodes_to_values(args, n)
遞歸遍歷 args
并在當前執行環境中查找每個 Node
的具體值。
參數
args (Argument)
– 需要查找具體值的數據結構n (Node)
–args
所屬的節點。僅用于錯誤報告。
返回類型:Optional[Union [tuple [Argument’, …], Sequence [Argument], Mapping [str , Argument], slice , range , Node, str , int , float, bool , complex , dtype, Tensor , device , memory_format , layout , OpOverload, SymInt, SymBool , SymFloat ]]
注意:此 API 保證向后兼容性。
output(target, args, kwargs)
執行一個output
節點。該操作實際上只是獲取output
節點引用的值并返回它。
參數
target (Target)
– 該節點的調用目標。有關語義詳情請參閱
Nodeargs (Tuple)
– 本次調用的位置參數元組kwargs (Dict)
– 本次調用的關鍵字參數字典
返回值:輸出節點引用的返回值
返回類型:任意類型
注意:此API保證向后兼容。
placeholder(target, args, kwargs)
執行一個placeholder
節點。請注意這是有狀態的:
Interpreter
內部維護了一個針對run
方法傳入參數的迭代器,本方法會返回該迭代器的next()結果。
參數
target (Target)
– 該節點的調用目標。關于語義的詳細信息請參閱Nodeargs (Tuple)
– 本次調用的位置參數元組kwargs (Dict)
– 本次調用的關鍵字參數字典
返回值:獲取到的參數值。
返回類型:任意類型
注意:此API保證向后兼容。
run(*args, initial_env=None, enable_io_processing=True)
通過解釋執行模塊并返回結果。
參數
*args
– 按位置順序傳遞給模塊的運行參數initial_env (Optional[Dict[Node, Any]])
– 可選的執行初始環境。這是一個將節點映射到任意值的字典。例如,可用于預先填充某些節點的結果,從而在解釋器中僅進行部分求值。enable_io_processing ([bool])
– 如果為true,我們會在使用輸入和輸出之前,先用圖的process_inputs和process_outputs函數對它們進行處理。
返回值:執行模塊后返回的值
返回類型:任意
注意:此API保證向后兼容。
run_node(n)
運行特定節點 n
并返回結果。
根據 node.op
的類型,調用對應的占位符、get_attr、call_function、call_method、call_module 或 output 方法。
參數
n (Node)
– 需要執行的節點
返回值:執行節點 n
的結果
返回類型:任意類型
注意:此 API 保證向后兼容性。
class torch.fx.Transformer(module)
Transformer
是一種特殊類型的解釋器,用于生成新的 Module
。它提供了一個 transform()
方法,返回轉換后的 Module
。與 Interpreter
不同,Transformer
不需要參數即可運行,完全基于符號化方式工作。
示例
假設我們需要將所有 torch.neg
實例與 torch.sigmoid
互換(包括它們的 Tensor
方法等效形式)。可以通過如下方式子類化 Transformer
:
class NegSigmSwapXformer(Transformer):def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:if target == torch.sigmoid:return torch.neg(*args, **kwargs)return super().call_function(target, args, kwargs)def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:if target == "neg":call_self, *args_tail = argsreturn call_self.sigmoid(*args_tail, **kwargs)return super().call_method(target, args, kwargs)def fn(x):return torch.sigmoid(x).neg()gm = torch.fx.symbolic_trace(fn)transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
參數
module ([GraphModule](https://pytorch.org/docs/stable/data.html#torch.fx.GraphModule "torch.fx.GraphModule"))
– 待轉換的Module
對象。
注意:此API保證向后兼容性。
call_function(target, args, kwargs)
注意:該 API 保證向后兼容。
返回類型
Any
call_module(target, args, kwargs)
注意:此 API 保證向后兼容。
返回類型
Any
get_attr(target, args, kwargs)
執行一個 get_attr
節點。在 Transformer
中,該方法被重寫以便向輸出圖中插入新的 get_attr
節點。
參數
target (Target)
– 該節點的調用目標。關于語義的詳細信息請參閱
Nodeargs (Tuple)
– 該調用的位置參數元組kwargs (Dict)
– 該調用的關鍵字參數字典
返回類型
Proxy
注意:此 API 保證向后兼容。
placeholder(target, args, kwargs)
執行一個 placeholder
節點。在 Transformer
中,該方法被重寫以便向輸出圖中插入新的 placeholder
。
參數
target (Target)
– 該節點的調用目標。關于語義的詳細信息請參閱 Nodeargs (Tuple)
– 該調用的位置參數元組kwargs (Dict)
– 該調用的關鍵字參數字典
返回類型:Proxy
注意:此 API 保證向后兼容。
transform()
轉換 self.module
并返回轉換后的 GraphModule
。
注意:此 API 保證向后兼容性。
返回類型 : GraphModule
torch.fx.replace_pattern(gm, pattern, replacement)
在GraphModule的圖結構(gm
)中,匹配所有可能的非重疊運算符集及其數據依賴關系(pattern
),然后將每個匹配到的子圖替換為另一個子圖(replacement
)。
參數
gm (GraphModule)
- 封裝待操作圖的GraphModulepattern (Union[Callable, GraphModule])
- 需要在gm
中匹配并替換的子圖replacement (Union[Callable, GraphModule])
- 用于替換pattern
的子圖
返回值:返回一個Match
對象列表,表示原始圖中與pattern
匹配的位置。如果沒有匹配項則返回空列表。Match
定義如下:
class Match(NamedTuple):# Node from which the match was foundanchor: Node# Maps nodes in the pattern subgraph to nodes in the larger graphnodes_map: Dict[Node, Node]
返回類型:List[Match]
示例:
import torch
from torch.fx import symbolic_trace, subgraph_rewriterclass M(torch.nn.Module):def __init__(self) -None:super().__init__()def forward(self, x, w1, w2):m1 = torch.cat([w1, w2]).sum()m2 = torch.cat([w1, w2]).sum()return x + torch.max(m1) + torch.max(m2)def pattern(w1, w2):return torch.cat([w1, w2])def replacement(w1, w2):return torch.stack([w1, w2])traced_module = symbolic_trace(M())subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上述代碼會先在 traced_module
的 forward
方法中匹配 pattern
。模式匹配基于使用-定義關系而非節點名稱進行。例如,若 pattern
中包含 p = torch.cat([a, b])
,則可以在原始 forward
函數中匹配到 m = torch.cat([a, b])
,即使變量名不同(p
與 m
)也不影響。
pattern
中的 return
語句僅根據其值進行匹配,它可能與更大圖中的 return
語句匹配,也可能不匹配。換句話說,模式不必延伸至更大圖的末尾。
當模式匹配成功時,它將從更大的函數中被移除,并由 replacement
替換。如果更大函數中存在多個 pattern
匹配項,每個非重疊的匹配項都會被替換。若出現匹配重疊的情況,則替換重疊匹配集中最先找到的匹配項(此處的"最先"定義為節點使用-定義關系拓撲排序中的第一個節點。大多數情況下,第一個節點是緊接 self
后出現的參數,而最后一個節點是函數返回的內容)。
需要特別注意:pattern
可調用對象的參數必須在該可調用對象內部使用,且 replacement
可調用對象的參數必須與模式匹配。第一條規則解釋了為何上述代碼塊中 forward
函數有參數 x, w1, w2
,而 pattern
函數只有參數 w1, w2
——因為 pattern
未使用 x
,故不應將 x
指定為參數。
關于第二條規則的示例,考慮替換…
def pattern(x, y):return torch.neg(x) + torch.relu(y)
with
def replacement(x, y):return torch.relu(x)
在這種情況下,replacement
需要與pattern
相同數量的參數(包括x
和y
),即使參數y
在replacement
中并未使用。
調用subgraph_rewriter.replace_pattern
后,生成的Python代碼如下所示:
def forward(self, x, w1, w2):stack_1 = torch.stack([w1, w2])sum_1 = stack_1.sum()stack_2 = torch.stack([w1, w2])sum_2 = stack_2.sum()max_1 = torch.max(sum_1)add_1 = x + max_1max_2 = torch.max(sum_2)add_2 = add_1 + max_2return add_2
注意:該 API 保證向后兼容。
torch.fx.experimental
警告:這些API屬于實驗性質,可能會隨時變更而不另行通知。
torch.fx.experimental.symbolic_shapes
ShapeEnv | |
---|---|
DimDynamic | 控制如何為維度分配符號。 |
StrictMinMaxConstraint | 對客戶端:該維度的大小必須在’vr’范圍內(指定包含性上下界),且必須為非負數且不應為0或1(但參見下方注意事項)。 |
RelaxedUnspecConstraint | 對客戶端:無顯式約束;約束由追蹤過程中的守衛隱式推斷得出。 |
EqualityConstraint | 表示并判定輸入源之間的各類相等性約束。 |
SymbolicContext | 數據結構,指定在create_symbolic_sizes_strides_storage_offset 中如何創建符號;例如,應為靜態還是動態。 |
StatelessSymbolicContext | 通過DimDynamic 和DimConstraint 給定的symbolic_context判定,在create_symbolic_sizes_strides_storage_offset 中創建符號。 |
StatefulSymbolicContext | 通過Source:Symbol緩存給定的symbolic_context判定,在create_symbolic_sizes_strides_storage_offset 中創建符號。 |
SubclassSymbolicContext | 可追蹤張量子類的內部張量的正確符號上下文可能與外部符號上下文不同。 |
DimConstraints | 針對符號維度約束系統的自定義求解器。 |
ShapeEnvSettings | 封裝所有可能影響FakeTensor調度的形狀環境設置。 |
ConvertIntKey | |
CallMethodKey | |
PropagateUnbackedSymInts | |
DivideByKey | |
InnerTensorKey | |
hint_int | 獲取整數的提示值(基于運行時觀察到的底層實際值)。 |
is_concrete_int | 檢查SymInt底層對象是否為具體值的實用工具。 |
is_concrete_bool | 檢查SymBool底層對象是否為具體值的實用工具。 |
is_concrete_float | 檢查SymInt底層對象是否為具體值的實用工具。 |
has_free_symbols | bool(free_symbols(val))的快速版本 |
has_free_unbacked_symbols | bool(free_unbacked_symbols(val))的快速版本 |
definitely_true | 僅當能確定a為True時返回True,過程中可能引入守衛。 |
definitely_false | 僅當能確定a為False時返回True,過程中可能引入守衛。 |
guard_size_oblivious | 以無視大小的方式對符號布爾表達式執行守衛。 |
sym_eq | 類似==,但在列表/元組上運行時,會遞歸測試相等性并使用sym_and連接結果,不引入守衛。 |
constrain_range | 應用約束使傳入的SymInt必須在min-max范圍內(包含邊界),且不引入SymInt的守衛(意味著可用于未綁定的SymInt)。 |
constrain_unify | 給定兩個SymInt,約束它們必須相等。 |
canonicalize_bool_expr | 通過將布爾表達式轉換為lt/le不等式并將所有非常量項移至右側,實現規范化。 |
statically_known_true | 如果x可簡化為常量且為真,則返回True。 |
lru_cache | |
check_consistent | 測試兩個"meta"值(通常為Tensor或SymInt)是否具有相同的值,例如在重追蹤后。 |
compute_unbacked_bindings | 在運行fake tensor傳播并生成example_value結果后,遍歷example_value查找新綁定的未支持符號并記錄其路徑供后續使用。 |
rebind_unbacked | 假設我們正在重追蹤一個已有FX圖,該圖先前進行過fake tensor傳播(因此存在未支持的SymInt)。 |
resolve_unbacked_bindings | |
is_accessor_node |
torch.fx.experimental.proxy_tensor
make_fx | 給定函數f,返回一個新函數。當使用有效參數執行該函數時,會返回一個FX GraphModule,表示執行過程中所執行的操作集合。 |
---|---|
handle_sym_dispatch | 調用當前活動的代理跟蹤模式,對操作SymInt/SymFloat/SymBool參數的函數進行符號調度跟蹤。 |
get_proxy_mode | 獲取當前活動的代理跟蹤模式,如果當前未處于跟蹤狀態則返回None。 |
maybe_enable_thunkify | 在此上下文管理器內,如果正在進行make_fx跟蹤,將對所有SymNode計算進行thunkify處理,并避免將其跟蹤到圖中,除非確實需要。 |
maybe_disable_thunkify | 在某個上下文中禁用thunkification功能。 |
2025-05-10(六)